💡 📢 ☑️ Remember to read the README.md file for helpful hints on the best ways to view/navigate this project.
If you view this notebook on GitHub, you will be missing important content: some charts/diagrams/features are not rendered there. This is standard, well-known behaviour.
Consider viewing the pre-rendered HTML files, or run all notebooks end to end after enabling the feature flags that control long-running operations.
If you choose to run this locally, there are some prerequisites:
- Python 3.9
- `pip install -r requirements.txt` before proceeding
- `cm-super` and `dvipng` for correct rendering of some LaTeX content

Module 4: Deep Learning - Sprint 3: Practical Deep Learning
===========================================================
Age and Gender classification
-----------------------------
Congratulations on reaching your last project. We will put into practice the concepts we have learned so far.
In this lesson, we will tackle two fairly simple problems: gender classification and age classification from a close-up image of a person. But instead of building two separate models, your task is to build one model that performs both tasks. You will then analyze the model from an ethical point of view and see what sort of dangers and caveats such models can have.
The exercise today is to train a multi-objective image classifier using data from https://www.kaggle.com/jangedoo/utkface-new. You will train a single model that can predict gender and age.
Find out more about multi-task learning: https://ruder.io/multi-task, https://www.youtube.com/watch?v=UdXfsAr4Gjw
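Before diving in, here is a minimal sketch of the multi-task pattern those links describe: one shared backbone, one head per task, and a single combined loss. The tiny linear backbone, the head sizes, and the 10 age bins are placeholders for illustration, not the architecture built later in this notebook.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MultiTaskSketch(nn.Module):
    """Shared backbone with one output head per task (illustrative only)."""

    def __init__(self, n_age_bins: int = 10):
        super().__init__()
        # placeholder backbone; a real model would use a CNN here
        self.backbone = nn.Sequential(
            nn.Flatten(), nn.Linear(3 * 64 * 64, 128), nn.ReLU()
        )
        self.gender_head = nn.Linear(128, 2)        # binary classification
        self.age_head = nn.Linear(128, n_age_bins)  # age bins (or 1 unit for regression)

    def forward(self, x):
        features = self.backbone(x)  # computed once, shared by both heads
        return self.gender_head(features), self.age_head(features)

model = MultiTaskSketch()
x = torch.randn(4, 3, 64, 64)
gender_logits, age_logits = model(x)
# one loss per task, summed so a single backward pass trains both heads
loss = F.cross_entropy(gender_logits, torch.tensor([0, 1, 0, 1])) + F.cross_entropy(
    age_logits, torch.tensor([2, 5, 0, 9])
)
loss.backward()
```

The key design choice is that the backbone runs once per image, so both predictions come from a single forward pass, and the per-task losses can be weighted if one task should dominate training.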
Concepts to explore
-------------------
- Classification task
- Convolutional neural network
- AI ethics and bias
- Model interpretability
Requirements
------------
- You should go through the standard cycle of EDA-model-evaluation.
- Create a single model that returns age and gender in a single pass
- Analyze model performance
- Understand which are the best/worst performing samples.
- Use LIME for model interpretability with images. Understand what your model pays attention to.
Once you are done with these tasks, evaluate any ethical issues with this model:
- Identify how this model can be biased and check whether the results show signs of these issues.
- Analyze bad predictions. Do you see any patterns in misclassified samples that could cause ethical concerns?
- Describe in which scenarios your model can be biased. Propose solutions to mitigate the bias.
- Think of a domain where this model could or could not be deployed.
Evaluation criteria
-------------------
- EDA
- A single end-to-end trainable deep learning model is built
- Correctly modeled classification/regression
- Correct selection of loss function(s)
- Model interpretability tools used and insights made
- Model aggregate performance
- Quality of ethical concerns raised
- Code quality
from IPython.display import display, Markdown, clear_output, HTML, IFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import ListedColormap
from matplotlib import gridspec
import itertools
import glob
from tqdm import tqdm
import dill
from datetime import datetime
import numba
import imagehash
import numpy as np
import pandas as pd
import seaborn as sns
# from fastai.vision.all import
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms, datasets, models
import torch.nn.functional as F
import lime
from lime import lime_image
from torchmetrics.classification import (
MulticlassConfusionMatrix,
MulticlassAccuracy,
MulticlassPrecision,
MulticlassRecall,
MulticlassF1Score,
)
from torchmetrics import ConfusionMatrix
from PIL import Image, ImageDraw, ImageFile
from scipy.stats import chi2_contingency, chisquare, laplace, kstest
from sklearn.metrics import (
precision_score,
recall_score,
f1_score,
roc_curve,
auc,
ConfusionMatrixDisplay,
)
from random import random, seed, shuffle
import logging
import warnings
import os
import shutil
from os import path
from watermark import watermark
from utils import *
from utils import __
loading utils modules... ✅ completed
configuring autoreload... ✅ completed
print(watermark())
print(watermark(packages="torch,torchvision,torchmetrics,numpy,pandas,sklearn,scipy"))
print(watermark(conda=True))
Last updated: 2024-02-29T09:30:23.537801+01:00

Python implementation: CPython
Python version       : 3.9.16
IPython version      : 8.10.0

Compiler    : GCC 11.2.0
OS          : Linux
Release     : 5.15.0-88-generic
Machine     : x86_64
Processor   : x86_64
CPU cores   : 16
Architecture: 64bit

torch       : 2.1.0
torchvision : 0.16.0
torchmetrics: 1.2.0
numpy       : 1.23.5
pandas      : 2.1.4
sklearn     : 1.2.2
scipy       : 1.9.3

conda environment: py39_lab4
seed(100)
pd.options.display.max_rows = 30
pd.options.display.max_colwidth = 50
util.check("done")
✅
Let's use black to auto-format all our cells so they adhere to PEP8
import lab_black
%reload_ext lab_black
util.patch_nb_black()
# fmt: off
# fmt: on
from sklearn import set_config
set_config(transform_output="pandas")
sns.set_theme(context="notebook", style="whitegrid")
plt.rcParams["axes.grid"] = True
moonstone = "#62b6cb"
moonstone_rgb = util.hex_to_rgb(moonstone)
moonstone_rgb_n = np.array(moonstone_rgb) / 255
logger = util.configure_logging(jupyterlab_level=logging.WARN, file_level=logging.DEBUG)
warnings.filterwarnings("ignore", category=FutureWarning)
# import warnings
# warnings.filterwarnings('error', category=pd.errors.DtypeWarning)
# kkalera's logger config
logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
file_handler = logging.FileHandler("notebook_logging.log")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.addHandler(stream_handler)
def ding(title="Ding!", message="Task completed"):
    """
    This helper only works on Linux (it shells out to notify-send).
    We only use it to get notified when a long-running process completes.
    """
    for _ in range(2):
        !notify-send '{title}' '{message}'
Let's also create a simple feature toggle that we can use to skip expensive operations while working on the notebook (to save some time!).
Set it to True to run absolutely everything; set it to False to skip optional steps/exploratory work.
def run_entire_notebook(filename: str = None, value_only=False):
    run_all_flag = False
    if value_only:
        return run_all_flag
    if not run_all_flag:
        print("skipping optional operation")
        fullpath = f"cached/printouts/{filename}.txt"
        if filename is not None and os.path.exists(fullpath):
            print("==== 🗃️ printing cached output ====")
            with open(fullpath) as f:
                print(f.read())
    return run_all_flag
kaggle_dataset_name = "jangedoo/utkface-new"
db_filename = "UTKFace"
auto_kaggle.download_dataset(
kaggle_dataset_name,
db_filename,
timeout_seconds=3 * 60,
is_competition=False,
)
__
Kaggle API 1.5.13 - login as 'edualmas'
File [dataset/UTKFace] already exists locally! No need to re-download data for dataset [jangedoo/utkface-new]
Let's take a quick look at the downloaded dataset to see if we see any interesting patterns
# SOURCE: https://susanqq.github.io/UTKFace/
gender_map = {
0: "male",
1: "female",
}
ethnicity_map = {
0: "white",
1: "black",
2: "asian",
3: "indian",
4: "other",
}
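Per the UTKFace page linked above, the labels are encoded directly in each filename as `[age]_[gender]_[ethnicity]_[timestamp].jpg.chip.jpg`. A small parser sketch (the helper name is ours; the maps are repeated so the block is self-contained):

```python
import os

gender_map = {0: "male", 1: "female"}
ethnicity_map = {0: "white", 1: "black", 2: "asian", 3: "indian", 4: "other"}

def parse_utkface_name(path: str) -> dict:
    """Split '[age]_[gender]_[ethnicity]_[timestamp].jpg.chip.jpg' into labels."""
    age, gender, ethnicity = os.path.basename(path).split("_")[:3]
    return {
        "age": int(age),
        "gender": gender_map[int(gender)],
        "ethnicity": ethnicity_map[int(ethnicity)],
    }

labels = parse_utkface_name("dataset/UTKFace/17_1_1_20170114030034621.jpg.chip.jpg")
# → {"age": 17, "gender": "female", "ethnicity": "black"}
```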
Some Kaggle datasets contain a fair amount of duplicated data.
Let's take a quick look and see if we can save valuable disk space (and training time) by removing it.
We'll take the easy steps first: understanding the dataset and detecting/deleting any byte-identical duplicated files.

We see that the folder structure might show some duplication.
We are not 100% sure that there is duplication, but we would not be surprised if there were some.
Let's explore a bit more in depth and see what we find. Right now, there are a few scenarios that we could encounter:
Ideal scenario: the nested folders are exact byte-for-byte copies of the top-level ones. If so, we can delete them both!
Plausible (non-ideal) scenarios: the folders only partially overlap, or contain files with the same names but different content.
Let's use meld, a GUI tool for bulk file/folder comparison that can compare the binary content of entire folders.
Initially we see that it reports "identical contents" for the nested and non-nested folders.
We want to run one last test to make sure: if we manually modify one of the files at random, meld does detect the change ✅
If we edit its sibling file to have the same content, it no longer detects the change and goes back to reporting "identical contents" (thus also proving that it inspects the binary content of the files, not just the timestamps).

This gives us enough confidence to say that the crop_part1 folders are identical.
Let's do the same test on the other subfolder:

The same applies to the other subfolder (UTKFace vs utkface_aligned_cropped/UTKFace).
We're fairly confident that this evidence allows us to delete half the content of the dataset as duplicated (after performing binary file content comparison), saving us from processing over 33,000 files and 120 MB of disk space.
Nothing kills disk performance like accessing thousands of tiny files!
Let's do one LAST check, just to be extra sure... since making a mistake at this point would mean incorrectly destroying half the data. A 5-minute check now can save us major problems later.

There seems to be a difference in the inner vs outer "crop" folder... but we suspect it's because of the file we tampered with. Let's do one LAST check:

The only diff detected by sha256 is the file we tampered with ✅
We are now confident that these folders are duplicates and can be safely deleted.
We have figured out that 2 of the 4 folders were identical duplicates, but we still have not determined what is in the 2 folders that are left.
!ls -l dataset/crop_part1/ | wc --lines
9781
!ls -l dataset/UTKFace/ | wc --lines
23709
It seems they contain different numbers of files. It could be that one folder contains a subset of the pictures from the other.
We really need to make sure we eliminate all duplicate images, to avoid data leakage and artificially inflating the performance of our model.
A quick check with meld seems to show that crop_part1 is a subset of UTKFace... ALMOST:

It seems that there is at least 1 file in crop_part1 that is not in UTKFace. Let's automate an easy way to get a single copy of each unique file; we don't want to do this manually and risk missing something.
We will build a dict mapping filehash → filename, storing every file we encounter. This guarantees that only a single copy of each picture is kept (because writes to the same dict key overwrite each other).
We will then be able to copy those files to a new sanitized folder of known-to-be-unique pictures:
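The dict idea can be sketched as follows. In-memory bytes stand in for real files here (the notebook hashes the actual files with sha512sum), and note this sketch keeps the *last* copy seen for each hash, while the pandas `groupby(...).first()` used below keeps the first:

```python
import hashlib

def unique_by_content(files: dict[str, bytes]) -> dict[str, str]:
    """Map content hash -> one representative filename.
    Duplicates write to the same dict key and overwrite the previous
    entry, so exactly one copy per distinct content survives."""
    unique: dict[str, str] = {}
    for name, data in files.items():
        unique[hashlib.sha256(data).hexdigest()] = name
    return unique

# three "files", two of which have identical bytes
demo = {"a.jpg": b"\x01\x02", "b.jpg": b"\x01\x02", "c.jpg": b"\x03"}
survivors = unique_by_content(demo)  # only 2 entries remain
```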
def list_all_files(dataset_folder: str) -> list[str]:
    return glob.glob(f"{dataset_folder}/*")
Let's calculate a hash for all the files:
if run_entire_notebook():
    !sha512sum dataset/crop_part1/*.jpg > dataset/individual_hashes.txt
    !sha512sum dataset/UTKFace/*.jpg >> dataset/individual_hashes.txt
skipping optional operation
!wc -l dataset/individual_hashes.txt
33488 dataset/individual_hashes.txt
!head -n 3 dataset/individual_hashes.txt
14ac5ab1a9d5dbd6243c82c578b413a0a980ae94b2b49f9d37c20b2eab4ec5222389a7f13e5b1ff968f684dba4f77d5475d1b5ab5e39c9361cfc6d570181df98  dataset/crop_part1/10_0_0_20161220222308131.jpg.chip.jpg
374819e76623a8844bc6e086c4336f411a53cee101687e52988928f63ec4b14e4b9a8f498d202c80feac5168fd900f97eacabdb814ef483e53ed0295143102a5  dataset/crop_part1/10_0_0_20170103200329407.jpg.chip.jpg
96efbbd97e6b37d2e4213dd04928f5bec0698d7d0ff347aee82f5e3e6a5c343e215cc56e2d7995a9da33b618b2d0f9b89fb26304c45ee8db208d55b71aae5097  dataset/crop_part1/10_0_0_20170103200522151.jpg.chip.jpg
hashes = pd.read_fwf(
"dataset/individual_hashes.txt",
widths=[128, 2, 1000],
header=None,
)
hashes.columns = ["filehash", "sep", "filename"]
hashes.drop(columns="sep", inplace=True)
duplicated = hashes["filehash"].value_counts()
duplicated = duplicated[duplicated > 1]
# let's only count the duplicates, for our metrics.
# so 1 = 1 duplicate = 2 files with same content
duplicated = duplicated - 1
duplicated
filehash
a23fa03989f3f48067fe5060db21924c456f51fd416bb0a92a4a797beee5c43f3d137318a199554ff33f4e0286a5b15beb78d22dddc41d058f058f25c5ee9f3e 5
706ea2e3da0bcaf39421356dddd35db2c4086108001e9eee5bdcec63479ca6a74738955dec6179f74985ca170d6e4a05e7228b4d03aeae709a91a229831c2378 5
0ef50bddb35c616916ed8133064a2be175add0e3e561fff3758bf9c44aef08ad11310ff25bc09ad7016c065dd64ea73fe90d68a8021148ea7f36de79c6f68170 5
8e0c4ee947b63171f37610628f4edad47fb002c0d6083dcf998726e373871ea00d852e9b7077e81c970f7812febcfd6918ef30026734809aa16794eb237c7d4c 5
952c372684eba466c7c178dd30d1346b67570c9ea2836237ef860f4ce928877feda7972f2791ea7e6ec4d326af105c4cb7d96bdcbd9170971f60244eb41463ab 5
..
b5b8e44204be12c627d4e7ede365eb2599d838d10ae5ca3cd63b6a0a0a8e61e75f84f2f097966098d6499c13bbd5320315270a570402ead4bf6e0e40cefc61d0 1
462c00d7a62686c570735275ef6ebad3a1d96a7454a09b5fbed3ff4b881c8563bc3cc1d5be3af606ae7c59e56e455bb9d12e7992063d32656295c6e293da3080 1
e8377dadc2aa45661dcb4c42001929f5f468ee38e6e43def72a01dcb0380a4cf7d3cda5763c05df5ab38e340027b1be6b4267d1a0927b953da0fdc1eefb1d89c 1
033d3739a24f770cc8330563cf933880d354714f95fc734e2473a0e6e473dccfcb3722bfde881217891cc25a6e58466766e62dd9b427017282f550c26e8f15f5 1
c758e981f4c9207ef3c56d794d5cd1b0fdf24a188eee70198b2f7a13c79ca67ed092f7e4dffb7797a0a2485241536776a7c2c22fdfdf6b2fbb3e91381da638cd 1
Name: count, Length: 9883, dtype: int64
duplicated.sum()
10170
It seems that, out of the ~33k files, ~10k are duplicated!
This could be a major source of leakage, giving us a false sense of performance, as well as a drain on resources. We will not keep the duplicated files.
Technical note: the chance of a random sha512 collision is extremely small (on the order of $1 \over 10^{77}$), so we're not going to bother checking candidates manually.
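For our ~33k files specifically, a back-of-the-envelope birthday bound gives an even smaller figure (using the standard approximation p ≈ n²/2^(b+1) for n items and a b-bit hash):

```python
# Birthday-bound sanity check: for n items and a b-bit hash, the
# probability of any accidental collision is roughly n^2 / 2^(b+1).
n = 33_488  # files hashed in this notebook
b = 512     # bits in a sha512 digest
p = n * n / 2 ** (b + 1)
print(f"collision probability ≈ {p:.1e}")  # astronomically small
```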
dir_clean_1 = "dataset/clean1_removed_identical_hash/"
os.makedirs(dir_clean_1, exist_ok=True)
hashes["new_filename"] = (
hashes["filename"]
.str.replace("dataset/", dir_clean_1)
.str.replace("UTKFace/", "")
.str.replace("crop_part1/", "")
)
files_unique_hashes = hashes.groupby("filehash").first()
files_unique_hashes.head()
| filehash | filename | new_filename |
|---|---|---|
| 000c10431583581973efb97a4ab0f8b08850e72d797988b8b8efd44ac78c5c1668ace33dbe7bc87ddff54dc1911034905ad226dea45bf076dfcab08ba8a44162 | dataset/UTKFace/17_1_1_20170114030034621.jpg.c... | dataset/clean1_removed_identical_hash/17_1_1_2... |
| 000edaff14d1bd3d804af056b37a159f0154453164d9ef0f8414c097292affb8134a809fbd6cb0e3990a026d58e746e873e0bd7451f51e076ef78d292a82b1b8 | dataset/UTKFace/35_0_1_20170117121610224.jpg.c... | dataset/clean1_removed_identical_hash/35_0_1_2... |
| 001117a1aeb92e6b8d34e19490472b77a4b5751365bed9d43fa0e16d73ef2e28b43760d33a7524e24cc6e5edbfe0a327b2f8169697e6501982e52a45e6e8601d | dataset/crop_part1/42_0_0_20170104183950934.jp... | dataset/clean1_removed_identical_hash/42_0_0_2... |
| 0018ef8ae5a1e95d1447e0c4a36e1de0362923503490913085840060d6e90ff67c4e51e2a741a0409fa0ae63ea64c73cf561c11d464392d3b12df8a6462eb99a | dataset/UTKFace/31_0_0_20170117181923333.jpg.c... | dataset/clean1_removed_identical_hash/31_0_0_2... |
| 001e33bc5fa66b3caaae5bda44c41de6b7d91aaf66c2b782b0ef11809ed35aadae3fc85582efa15935f404beb49c189c0eb15d63fc782e30a6b92501dca5a3ef | dataset/UTKFace/58_0_1_20170113174947234.jpg.c... | dataset/clean1_removed_identical_hash/58_0_1_2... |
if run_entire_notebook("excluding_files_identical_hash"):
    pbar = tqdm(files_unique_hashes.iterrows(), desc="removing files with identical hash")
    for index, row in pbar:
        shutil.copy(row.filename, row.new_filename)
skipping optional operation
==== 🗃️ printing cached output ====
selecting unique files: 23318it [00:01, 19289.47it/s]
We're done!
The dataset/clean1_removed_identical_hash/ directory now contains each unique picture, without duplicates.
files_unique_hashes
| filehash | filename | new_filename |
|---|---|---|
| 000c10431583581973efb97a4ab0f8b08850e72d797988b8b8efd44ac78c5c1668ace33dbe7bc87ddff54dc1911034905ad226dea45bf076dfcab08ba8a44162 | dataset/UTKFace/17_1_1_20170114030034621.jpg.c... | dataset/clean1_removed_identical_hash/17_1_1_2... |
| 000edaff14d1bd3d804af056b37a159f0154453164d9ef0f8414c097292affb8134a809fbd6cb0e3990a026d58e746e873e0bd7451f51e076ef78d292a82b1b8 | dataset/UTKFace/35_0_1_20170117121610224.jpg.c... | dataset/clean1_removed_identical_hash/35_0_1_2... |
| 001117a1aeb92e6b8d34e19490472b77a4b5751365bed9d43fa0e16d73ef2e28b43760d33a7524e24cc6e5edbfe0a327b2f8169697e6501982e52a45e6e8601d | dataset/crop_part1/42_0_0_20170104183950934.jp... | dataset/clean1_removed_identical_hash/42_0_0_2... |
| 0018ef8ae5a1e95d1447e0c4a36e1de0362923503490913085840060d6e90ff67c4e51e2a741a0409fa0ae63ea64c73cf561c11d464392d3b12df8a6462eb99a | dataset/UTKFace/31_0_0_20170117181923333.jpg.c... | dataset/clean1_removed_identical_hash/31_0_0_2... |
| 001e33bc5fa66b3caaae5bda44c41de6b7d91aaf66c2b782b0ef11809ed35aadae3fc85582efa15935f404beb49c189c0eb15d63fc782e30a6b92501dca5a3ef | dataset/UTKFace/58_0_1_20170113174947234.jpg.c... | dataset/clean1_removed_identical_hash/58_0_1_2... |
| ... | ... | ... |
| fff42adb969d15c96001a1e5bb2f5cfd71ee30c6fadad5a77d71d3ae74e5e7ee9b449bb388c69439cc9a78bf6a7c429a49e6f8b676a1979d865ff6d0046aba64 | dataset/crop_part1/61_0_3_20170109141653583.jp... | dataset/clean1_removed_identical_hash/61_0_3_2... |
| fff4dd576d231e871559d5326323537d89d642d683d54d3b42a9b4e3cf8526d75e175cc2423903e2f951a7d4fdd1ee7b7eae9816c0da92948044e6b5c8b03ed9 | dataset/crop_part1/54_0_0_20170104213004356.jp... | dataset/clean1_removed_identical_hash/54_0_0_2... |
| fff818ed2fc6d24e6d1560a33104df80f031f0508ae7ce161ff0f7ce98670ff99f8cad821db7a981a8eff56700e6b43dd52def6df2c873d870067a89823238f0 | dataset/crop_part1/61_1_0_20170110122324992.jp... | dataset/clean1_removed_identical_hash/61_1_0_2... |
| fff8f71b41b9dd31352b39ae16d6811791b7bf810f10326d527c6aa0806caeabbc710277a0bd5c4d95b60eccd36893dd3b06b07dc487b570df2df0c23096b902 | dataset/UTKFace/20_0_0_20170117140842001.jpg.c... | dataset/clean1_removed_identical_hash/20_0_0_2... |
| fff9e8f3cde1ebb72b715362774665d98e20a6345bcb46576cbdb1af86413a89b8cbb812d7fc33454e9ef0039d5af757f3a08e16c86eae345774ae64b7049405 | dataset/crop_part1/54_1_1_20170110120122138.jp... | dataset/clean1_removed_identical_hash/54_1_1_2... |
23318 rows × 2 columns
Now that we have deleted duplicated (identical) files, the easy part of the work is done.
But we suspect that there might be "similar but not identical" images.
These are harder to detect because they can look alike to a human without being byte-identical. We will try a few algorithms and keep the ones we find most effective for this part of the cleaning.
Check out the sandbox folder, which contains a few notebooks where we tried and tested several algorithms for detecting similar images.
In this notebook, we just kept the system that proved best (in terms of results and performance). Some of the algorithms were suboptimal, with a time complexity of $O(n^2)$, which is not ideal when you have 20k images: 20k * 20k = 400 million comparisons.
We had to find other approaches to keep the search sub-quadratic (linear, ideally).
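For *exact* hash matches, a single dict pass is already enough, replacing the all-pairs scan with O(n) work. A sketch with toy hash strings (not real phash values); note that near-matches within a Hamming-distance threshold still require the pairwise comparison performed later:

```python
from collections import defaultdict

def group_exact_matches(hashes: list[str]) -> dict[str, list[int]]:
    """Bucket image indices by identical hash in a single O(n) pass."""
    buckets: defaultdict[str, list[int]] = defaultdict(list)
    for idx, h in enumerate(hashes):
        buckets[h].append(idx)
    # only buckets holding 2+ images are duplicate candidates
    return {h: ids for h, ids in buckets.items() if len(ids) > 1}

groups = group_exact_matches(["aa", "bb", "aa", "cc", "bb", "aa"])
# → {"aa": [0, 2, 5], "bb": [1, 4]}
```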
image_similarity_analysis = (
pd.DataFrame(files_unique_hashes["new_filename"])
.reset_index(drop=True)
.rename(columns={"new_filename": "filename"})
)
image_similarity_analysis
| | filename |
|---|---|
| 0 | dataset/clean1_removed_identical_hash/17_1_1_2... |
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... |
| ... | ... |
| 23313 | dataset/clean1_removed_identical_hash/61_0_3_2... |
| 23314 | dataset/clean1_removed_identical_hash/54_0_0_2... |
| 23315 | dataset/clean1_removed_identical_hash/61_1_0_2... |
| 23316 | dataset/clean1_removed_identical_hash/20_0_0_2... |
| 23317 | dataset/clean1_removed_identical_hash/54_1_1_2... |
23318 rows × 1 columns
def hash_with_length(length_bytes=8):
    def hash_imagefile(filename: str) -> str:
        """
        Calculates a perceptual hash for a file so that "similar" images can be
        compared via a brightness score per cluster of pixels.
        The grid side length is configured via length_bytes: a higher value
        yields a grid with more cells (n^2) and slightly longer comparisons,
        with no real benefit for our use case.
        """
        return imagehash.phash(Image.open(filename), hash_size=length_bytes)

    return hash_imagefile
def hex_to_int(hash_value: str) -> int:
    return int(str(hash_value), base=16)
@cached_dataframe()
def similarity_analysis_with_hash():
    image_similarity_analysis["hash_str_8"] = image_similarity_analysis["filename"].apply(
        hash_with_length(8)
    )
    image_similarity_analysis["hash_str_8"] = image_similarity_analysis[
        "hash_str_8"
    ].astype("str")
    return image_similarity_analysis
similarity_analysis_with_hash()
Loading from cache [./cached/df/similarity_analysis_with_hash.parquet]
| | filename | hash_str_8 |
|---|---|---|
| 0 | dataset/clean1_removed_identical_hash/17_1_1_2... | 95f5aa5681c2b43e |
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | f8dfc300c7c07d32 |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | 91c54cd3c7c3326b |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | 95854a96e3db622d |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | 9dc51e9580d70b76 |
| ... | ... | ... |
| 23313 | dataset/clean1_removed_identical_hash/61_0_3_2... | 95da8549d28b726d |
| 23314 | dataset/clean1_removed_identical_hash/54_0_0_2... | 91ad5ed593a26835 |
| 23315 | dataset/clean1_removed_identical_hash/61_1_0_2... | d0944996c79325df |
| 23316 | dataset/clean1_removed_identical_hash/20_0_0_2... | c6c119c796d7313c |
| 23317 | dataset/clean1_removed_identical_hash/54_1_1_2... | dcd05b8863e24d3b |
23318 rows × 2 columns
Let's see if there are similar-ish pictures (pictures that have an identical Perceptual Hash).
@run
@cached_chart()
def similar_percept_hash():
    similar = similarity_analysis_with_hash()
    counts = similar["hash_str_8"].value_counts()
    counts = counts[counts > 1]  # only hashes shared by 2+ images
    sns.countplot(x=counts, order=counts.value_counts().index)
    return plt.gcf()
Loading from cache [./cached/charts/similar_percept_hash.png]
Observations
Outcome
Overall, the code below can compute all the comparisons in less than 45 seconds on a cheap laptop CPU, and in about 20 seconds on a desktop CPU.
Let's remember how Combinations work:
${C_k(n)} = {n\choose k} = {{n!} \over {k!(n-k)!}}$
${C_2(23318)} = {23318\choose 2} = {{23318!} \over {2!(23318-2)!}} = 271852903$
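We can sanity-check that pair count with the standard library:

```python
import math

n_images = 23_318
n_pairs = math.comb(n_images, 2)  # n * (n - 1) / 2 unordered pairs
print(n_pairs)  # → 271852903
```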
Based on what I can see in https://www.hackerfactor.com/blog/index.php?/archives/432-Looks-Like-It.html, I think we can treat each bit as an individual piece of information and compute the Hamming distance by simply comparing the bits one by one! :)
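That bit-by-bit idea can also be written as a single XOR on the integer form of the hash followed by a popcount. This is a standalone sketch for intuition; the numba code below instead compares pre-expanded bit arrays:

```python
def hamming_int(h1: int, h2: int) -> int:
    """Hamming distance between two same-width hashes stored as ints:
    XOR leaves a 1 bit exactly where the inputs differ, then count the 1s.
    (bin(...).count("1") rather than int.bit_count() to stay 3.9-friendly.)"""
    return bin(h1 ^ h2).count("1")

# two 64-bit phash values taken from the table above
a = int("95f5aa5681c2b43e", 16)
b = int("f8dfc300c7c07d32", 16)
dist = hamming_int(a, b)  # 0 (identical) .. 64 (all bits differ)
```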
We will use numba's just-in-time (@numba.jit) decorators to try to speed up the expensive computations.
@numba.jit(nopython=True)
def hamming_distance(hash1, hash2):
    return np.sum(hash1 != hash2)


@numba.jit(nopython=True, parallel=True)
def calculate_distances(i, hashes, threshold):
    is_duplicate = np.zeros(len(hashes), dtype=np.bool_)
    for j in numba.prange(i + 1, len(hashes)):
        hamming_dist = hamming_distance(hashes[i], hashes[j])
        if hamming_dist <= threshold:
            is_duplicate[j] = True
    return np.where(is_duplicate)[0]
def mark_duplicates(df, hash_col, threshold):
    """
    Lists image pairs that are detected as duplicates.
    """
    hashes = np.array(
        [list(map(int, bin(int(str(h), 16))[2:].zfill(64))) for h in df[hash_col]]
    )
    duplicates_dict = {}
    for i in tqdm(range(len(df))):
        duplicate_indices = calculate_distances(i, hashes, threshold)
        if len(duplicate_indices) > 0:
            duplicates_dict[i] = duplicate_indices
    return duplicates_dict
# Just to remember the performance of this algorithm for future benchmarks:
# 23k images ≈ 272 million pair comparisons in 19 seconds!
# Numba and JIT compilation rock!
@cached_with_pickle(force=run_entire_notebook("mark_duplicates_phash"))
def image_similarity_adj_matrix():
    similar = similarity_analysis_with_hash()
    duplicates_dict = mark_duplicates(similar, hash_col="hash_str_8", threshold=5)
    return duplicates_dict
image_similarity_adj = image_similarity_adj_matrix()
skipping optional operation
==== 🗃️ printing cached output ====
100%|██████████| 23318/23318 [00:19<00:00, 1167.02it/s]
Loading from cache [./cached/pickle/image_similarity_adj_matrix.pickle]
# just a sample, to visualize what this adjacency matrix looks like:
for idx in list(image_similarity_adj.keys())[80:110]:
    print(idx, image_similarity_adj[idx])
1169 [18638]
1198 [5356]
1226 [16114]
1231 [12054 17146]
1232 [13587 18548]
1236 [16791]
1247 [14691]
1248 [21699]
1254 [15256]
1256 [3180]
1261 [15627]
1281 [15290]
1312 [ 6256 20831]
1323 [7851]
1331 [13056 18001 18889]
1337 [2977]
1342 [7509 9209]
1356 [ 5219 11216]
1376 [11845]
1385 [13283]
1393 [9222]
1394 [7352]
1403 [2285 2777]
1410 [7755]
1438 [4272]
1447 [20833]
1448 [22896]
1469 [ 9073 19880 21486]
1470 [16086]
1509 [14763 20468]
@cached_dataframe(force=run_entire_notebook())
def duplicates_similarity_df():
    similar = similarity_analysis_with_hash()
    duplicates_dict = image_similarity_adj_matrix()
    all_keys = list(duplicates_dict.keys())
    all_values = np.concatenate(list(duplicates_dict.values())).tolist()
    all_duplicated_pics_ids = set(all_keys + all_values)
    similar["is_similar"] = False
    similar.loc[list(all_duplicated_pics_ids), "is_similar"] = True
    return similar
image_with_similarity = duplicates_similarity_df()
image_with_similarity
skipping optional operation
Loading from cache [./cached/df/duplicates_similarity_df.parquet]
| | filename | hash_str_8 | is_similar |
|---|---|---|---|
| 0 | dataset/clean1_removed_identical_hash/17_1_1_2... | 95f5aa5681c2b43e | True |
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | f8dfc300c7c07d32 | False |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | 91c54cd3c7c3326b | False |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | 95854a96e3db622d | False |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | 9dc51e9580d70b76 | False |
| ... | ... | ... | ... |
| 23313 | dataset/clean1_removed_identical_hash/61_0_3_2... | 95da8549d28b726d | False |
| 23314 | dataset/clean1_removed_identical_hash/54_0_0_2... | 91ad5ed593a26835 | False |
| 23315 | dataset/clean1_removed_identical_hash/61_1_0_2... | d0944996c79325df | True |
| 23316 | dataset/clean1_removed_identical_hash/20_0_0_2... | c6c119c796d7313c | True |
| 23317 | dataset/clean1_removed_identical_hash/54_1_1_2... | dcd05b8863e24d3b | False |
23318 rows × 3 columns
image_with_similarity.is_similar.value_counts()
is_similar
False    21848
True      1470
Name: count, dtype: int64
def plot_pics(ids: list[int]):
    f, ax = plt.subplots(1, len(ids), figsize=(len(ids) * 5, 5))
    for i in range(len(ids)):
        filename = str(image_with_similarity.loc[ids[i]]["filename"])
        print(filename)
        ax[i].imshow(Image.open(filename))
    plt.tight_layout()
    return plt.gcf()
Just for future reference, these are some of the images used to determine the threshold for detecting "similar enough" images.
We extracted their IDs while manually tuning the thresholds.
Observations:
# Threshold of 12
@run
@cached_chart()
def similar_threshold_12_a():
    print(
        """
dataset/clean1_removed_identical_hash/15_1_4_20170103230530985.jpg.chip.jpg
dataset/clean1_removed_identical_hash/37_1_0_20170109134008515.jpg.chip.jpg
dataset/clean1_removed_identical_hash/23_1_0_20170116221811019.jpg.chip.jpg
dataset/clean1_removed_identical_hash/30_1_1_20170116012131745.jpg.chip.jpg
dataset/clean1_removed_identical_hash/17_1_0_20170109214021426.jpg.chip.jpg"""
    )
    return plot_pics([93, 12244, 17302, 17699, 19950])
Loading from cache [./cached/charts/similar_threshold_12_a.png]
There are some similarities, but they are clearly different people
# Threshold of 6
@run
@cached_chart()
def similar_threshold_6_a():
    print(
        """
dataset/clean1_removed_identical_hash/32_1_0_20170117154910644.jpg.chip.jpg
dataset/clean1_removed_identical_hash/22_1_3_20170119153416689.jpg.chip.jpg
dataset/clean1_removed_identical_hash/32_1_0_20170117134809503.jpg.chip.jpg"""
    )
    return plot_pics([7834, 8970, 19856])
Loading from cache [./cached/charts/similar_threshold_6_a.png]
# Threshold of 6 - nope!
@run
@cached_chart()
def similar_threshold_6_b():
    print(
        """
dataset/clean1_removed_identical_hash/16_1_0_20170109213504335.jpg.chip.jpg
dataset/clean1_removed_identical_hash/27_1_3_20170104223505487.jpg.chip.jpg
dataset/clean1_removed_identical_hash/24_1_2_20170104234618170.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170117154940189.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152030871.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170117174028333.jpg.chip.jpg
dataset/clean1_removed_identical_hash/26_1_3_20170104235421282.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_2_20170104021040316.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152019467.jpg.chip.jpg
"""
    )
    return plot_pics([5122, 7395, 11127, 11444, 15119, 16152, 18056, 19818, 22147])
Loading from cache [./cached/charts/similar_threshold_6_b.png]
# Threshold of 5
# Good! the 2 clear matches from the previous round still appear at threshold 5!
# nice!
@run
@cached_chart()
def similar_threshold_5_a():
    print(
        """
dataset/clean1_removed_identical_hash/25_1_3_20170117152030871.jpg.chip.jpg
dataset/clean1_removed_identical_hash/25_1_3_20170117152019467.jpg.chip.jpg
"""
    )
    return plot_pics([15119, 22147])
Loading from cache [./cached/charts/similar_threshold_5_a.png]
# Threshold of 5
@run
@cached_chart()
def similar_threshold_5_b():
    print(
        """
dataset/clean1_removed_identical_hash/72_0_0_20170111201853033.jpg.chip.jpg
dataset/clean1_removed_identical_hash/65_0_0_20170120225159632.jpg.chip.jpg
dataset/clean1_removed_identical_hash/75_0_0_20170111205238382.jpg.chip.jpg
"""
    )
    return plot_pics([1044, 5567, 21425])
Loading from cache [./cached/charts/similar_threshold_5_b.png]
# Threshold of 5
@run
@cached_chart()
def similar_threshold_5_c():
    print(
        """
dataset/clean1_removed_identical_hash/30_0_1_20170113141654362.jpg.chip.jpg
dataset/clean1_removed_identical_hash/28_0_1_20170103225933161.jpg.chip.jpg
dataset/clean1_removed_identical_hash/32_0_1_20170113001102379.jpg.chip.jpg
"""
    )
    return plot_pics([15758, 16371, 19337])
Loading from cache [./cached/charts/similar_threshold_5_c.png]
# Threshold of 4
@run
@cached_chart()
def similar_threshold_4_a():
    print(
        """
dataset/clean1_removed_identical_hash/23_1_2_20170116173016687.jpg.chip.jpg
dataset/clean1_removed_identical_hash/23_1_2_20170116173145383.jpg.chip.jpg
dataset/clean1_removed_identical_hash/24_0_2_20170116164749805.jpg.chip.jpg
"""
    )
    return plot_pics([49, 515, 3572])
Loading from cache [./cached/charts/similar_threshold_4_a.png]
Threshold of 3
Looking good!
# threshold of 3
@run
@cached_chart()
def similar_threshold_3_a():
    print(
        """
dataset/clean1_removed_identical_hash/1_0_0_20170110213328641.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219200139603.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219204741557.jpg.chip.jpg
"""
    )
    return plot_pics([985, 15770, 20737])
Loading from cache [./cached/charts/similar_threshold_3_a.png]
# threshold of 3
@run
@cached_chart()
def similar_threshold_3_b():
    print(
        """
dataset/clean1_removed_identical_hash/1_0_0_20170110213328641.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219200139603.jpg.chip.jpg
dataset/clean1_removed_identical_hash/1_0_0_20161219204741557.jpg.chip.jpg
"""
    )
    return plot_pics([11415, 11752, 19133])
Loading from cache [./cached/charts/similar_threshold_3_b.png]
One last safety check to make sure that, despite being marked as similar, these images do in fact have different file hashes (i.e. they are not byte-identical):
@run
@cached_chart()
def similar_threshold_3_f():
print(
"""
dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg
"""
)
return plot_pics([17241, 18228, 21854])
Loading from cache [./cached/charts/similar_threshold_3_f.png]
!sha256sum dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
!sha256sum dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
!sha256sum dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg
1e5fad3db6fe0f7f172d9f5f358ea91f5a3dc7f6904f5b0761c591580a75d412  dataset/clean1_removed_identical_hash/35_0_0_20170117150935786.jpg.chip.jpg
01761854e5833bc963def5bfe0bcb8cc69050c608a6bf0582ecda5a946397674  dataset/clean1_removed_identical_hash/35_0_0_20170117170519707.jpg.chip.jpg
e88ff0ceca238030b763cdc585be1e9e119eb22028be59ab3fb835c243db48de  dataset/clean1_removed_identical_hash/28_0_0_20170117180626585.jpg.chip.jpg
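The same exact-duplicate check can be done from Python with the standard library, in case the shell utilities are unavailable (a minimal sketch; `sha256_of_file` is a hypothetical helper, not part of the notebook):

```python
import hashlib

def sha256_of_file(path: str) -> str:
    """Compute the SHA-256 hex digest of a file, reading in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

# e.g. sha256_of_file("dataset/clean1_removed_identical_hash/<filename>")
# should reproduce the sha256sum output above
```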
Observations:
@run
@cached_chart()
def similar_images_perceptual_hash():
return sns.countplot(similar, x="is_similar")
Loading from cache [./cached/charts/similar_images_perceptual_hash.png]
With this new technique, we have removed an additional 5% of the images, which were clearly similar or identical and would have caused some form of data leakage had any of the duplicates ended up in our test split.
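For reference, the `is_similar` flag is based on the Hamming distance between perceptual hashes. With hashes stored as hex strings (as in the `hash_str_8` column), the distance can be computed like this (a minimal sketch, assuming 64-bit hashes; the helper names are illustrative):

```python
def hamming_distance(hash_a: str, hash_b: str) -> int:
    """Number of differing bits between two equal-length hex hash strings."""
    assert len(hash_a) == len(hash_b)
    # XOR the two hashes; the popcount of the result is the bit distance
    return bin(int(hash_a, 16) ^ int(hash_b, 16)).count("1")

def is_similar(hash_a: str, hash_b: str, threshold: int = 3) -> bool:
    """Two images are flagged as similar when the distance is at or below
    the chosen threshold (we experimented with 5, 4 and 3 above)."""
    return hamming_distance(hash_a, hash_b) <= threshold
```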
image_with_similarity
| filename | hash_str_8 | is_similar | |
|---|---|---|---|
| 0 | dataset/clean1_removed_identical_hash/17_1_1_2... | 95f5aa5681c2b43e | True |
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | f8dfc300c7c07d32 | False |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | 91c54cd3c7c3326b | False |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | 95854a96e3db622d | False |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | 9dc51e9580d70b76 | False |
| ... | ... | ... | ... |
| 23313 | dataset/clean1_removed_identical_hash/61_0_3_2... | 95da8549d28b726d | False |
| 23314 | dataset/clean1_removed_identical_hash/54_0_0_2... | 91ad5ed593a26835 | False |
| 23315 | dataset/clean1_removed_identical_hash/61_1_0_2... | d0944996c79325df | True |
| 23316 | dataset/clean1_removed_identical_hash/20_0_0_2... | c6c119c796d7313c | True |
| 23317 | dataset/clean1_removed_identical_hash/54_1_1_2... | dcd05b8863e24d3b | False |
23318 rows × 3 columns
image_with_similarity["is_similar"].value_counts()
is_similar
False    21848
True      1470
Name: count, dtype: int64
without_similar_images = image_with_similarity[~image_with_similarity["is_similar"]]
without_similar_images
| filename | hash_str_8 | is_similar | |
|---|---|---|---|
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | f8dfc300c7c07d32 | False |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | 91c54cd3c7c3326b | False |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | 95854a96e3db622d | False |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | 9dc51e9580d70b76 | False |
| 5 | dataset/clean1_removed_identical_hash/9_0_0_20... | 95207e37e017691f | False |
| ... | ... | ... | ... |
| 23311 | dataset/clean1_removed_identical_hash/1_0_2_20... | 818c1e27de536cd9 | False |
| 23312 | dataset/clean1_removed_identical_hash/42_0_0_2... | b904d69792d7286d | False |
| 23313 | dataset/clean1_removed_identical_hash/61_0_3_2... | 95da8549d28b726d | False |
| 23314 | dataset/clean1_removed_identical_hash/54_0_0_2... | 91ad5ed593a26835 | False |
| 23317 | dataset/clean1_removed_identical_hash/54_1_1_2... | dcd05b8863e24d3b | False |
21848 rows × 3 columns
Let's check the number of files in the folder before removing anything:
!ls dataset/clean1_removed_identical_hash/ | wc -l
23318
dir_clean_2 = "dataset/clean2_removed_similar_images/"
if run_entire_notebook():
source_dir = dir_clean_1
target_dir = dir_clean_2
os.makedirs(target_dir, exist_ok=True)
for dissimilar_filename in without_similar_images["filename"]:
source = dissimilar_filename
target = source.replace(source_dir, target_dir)
shutil.copy(source, target)
skipping optional operation
def does_file_actually_exist(filename) -> bool:
return os.path.exists(filename)
file_exists = pd.DataFrame(without_similar_images["filename"])
file_exists["exists"] = file_exists["filename"].map(does_file_actually_exist)
file_exists.head()
| filename | exists | |
|---|---|---|
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | True |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | True |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | True |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | True |
| 5 | dataset/clean1_removed_identical_hash/9_0_0_20... | True |
file_exists.exists.value_counts()
exists
True    21848
Name: count, dtype: int64
!ls dataset/clean1_removed_identical_hash/ | wc -l
23318
!ls dataset/clean2_removed_similar_images/ | wc -l
21848
without_similar_images.head()
| filename | hash_str_8 | is_similar | |
|---|---|---|---|
| 1 | dataset/clean1_removed_identical_hash/35_0_1_2... | f8dfc300c7c07d32 | False |
| 2 | dataset/clean1_removed_identical_hash/42_0_0_2... | 91c54cd3c7c3326b | False |
| 3 | dataset/clean1_removed_identical_hash/31_0_0_2... | 95854a96e3db622d | False |
| 4 | dataset/clean1_removed_identical_hash/58_0_1_2... | 9dc51e9580d70b76 | False |
| 5 | dataset/clean1_removed_identical_hash/9_0_0_20... | 95207e37e017691f | False |
without_similar_images = without_similar_images.copy()
without_similar_images["filename"] = without_similar_images.copy()[
"filename"
].str.replace(dir_clean_1, dir_clean_2)
We want our dataset split into non-overlapping subsets so that we can test how the model performs on unseen data.
def parse_labels_from_filename(dir_name_prefix: str = dir_clean_2):
"""
categorize each file and derive its labels from the
underscore-separated tokens at the start of the filename
"""
def _enrich(row):
filename: str = row["filename"]
filename = filename.replace(dir_name_prefix, "")
tokens = filename.split("_")
if len(tokens) != 4:
logger.warning("unable to parse filename %s", filename)
return None
row["age"] = int(tokens[0])
row["gender"] = gender_map[int(tokens[1])]
row["ethnicity"] = ethnicity_map[int(tokens[2])]
return row
return _enrich
def tag_all_pics(df: pd.DataFrame) -> pd.DataFrame:
"""
tags the dataframe using the filename for age, gender and ethnicity
"""
pics = pd.DataFrame(df["filename"]).reset_index(drop=True)
pics.columns = ["filename"]
pics = pics.apply(parse_labels_from_filename(dir_clean_2), axis=1).dropna()
pics["age"] = pics["age"].astype(int)
return pics
@cached_dataframe()
def all_pics_tagged():
return tag_all_pics(without_similar_images)
pics = all_pics_tagged()
pics.head(10)
Loading from cache [./cached/df/all_pics_tagged.parquet]
| filename | age | gender | ethnicity | |
|---|---|---|---|---|
| 0 | dataset/clean2_removed_similar_images/35_0_1_2... | 35 | male | black |
| 1 | dataset/clean2_removed_similar_images/42_0_0_2... | 42 | male | white |
| 2 | dataset/clean2_removed_similar_images/31_0_0_2... | 31 | male | white |
| 3 | dataset/clean2_removed_similar_images/58_0_1_2... | 58 | male | black |
| 4 | dataset/clean2_removed_similar_images/9_0_0_20... | 9 | male | white |
| 5 | dataset/clean2_removed_similar_images/37_0_0_2... | 37 | male | white |
| 6 | dataset/clean2_removed_similar_images/20_1_0_2... | 20 | female | white |
| 7 | dataset/clean2_removed_similar_images/38_1_1_2... | 38 | female | black |
| 8 | dataset/clean2_removed_similar_images/28_1_3_2... | 28 | female | indian |
| 9 | dataset/clean2_removed_similar_images/26_1_2_2... | 26 | female | asian |
Note how the first/default mapping (0, 0) corresponds to ("white", "male").
The project only requires classifying based on gender and age, but we will also keep ethnicity for some parts of the analysis, as it might have an impact.
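Consistent with the tagged table above, the lookup maps used by `parse_labels_from_filename` would look roughly like this (a sketch; the real `gender_map`/`ethnicity_map` are defined earlier in the notebook and `parse_tokens` is an illustrative stand-in):

```python
# filename convention: {age}_{gender}_{ethnicity}_{timestamp}.jpg.chip.jpg
gender_map = {0: "male", 1: "female"}
ethnicity_map = {0: "white", 1: "black", 2: "asian", 3: "indian", 4: "other"}

def parse_tokens(filename: str):
    """Derive (age, gender, ethnicity) from a UTKFace-style filename."""
    tokens = filename.split("_")
    return int(tokens[0]), gender_map[int(tokens[1])], ethnicity_map[int(tokens[2])]
```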
@cached_dataframes()
def pics_splits():
splits = split_utils.split_dataset(
pics,
target_cols=["age", "gender"],
stratify_labels=False,
split_sizes={"train": 0.7, "val": 0.15, "test": 0.15},
)
return {
"train_X": splits["train"][0],
"train_y": splits["train"][1],
"val_X": splits["val"][0],
"val_y": splits["val"][1],
"test_X": splits["test"][0],
"test_y": splits["test"][1],
}
splits = pics_splits()
splits.keys()
Loading from cache [./cached/df_dict/pics_splits.h5]
dict_keys(['test_X', 'test_y', 'train_X', 'train_y', 'val_X', 'val_y'])
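`split_utils.split_dataset` is a project-specific helper whose implementation lives outside this notebook. For the unstratified case used here, its essence can be sketched as a plain shuffled index split (names and seed are illustrative assumptions):

```python
import random

def split_indices(n: int, sizes: dict, seed: int = 42) -> dict:
    """Shuffle row indices and cut them into non-overlapping splits."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # deterministic shuffle for reproducibility
    out, start = {}, 0
    for name, frac in sizes.items():
        end = start + round(n * frac)
        out[name] = idx[start:end]
        start = end
    out[name].extend(idx[start:])  # rounding leftovers go to the last split
    return out

splits_sketch = split_indices(21848, {"train": 0.7, "val": 0.15, "test": 0.15})
```

With 21848 rows and a 0.7/0.15/0.15 split this yields val/test splits of 3277 rows each, matching the counts seen later.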
We will keep the ethnicity column for future bias analysis, but it will not be fed to the model during training. This column is just here for convenience:
splits["train_X"].head()
| filename | ethnicity | |
|---|---|---|
| 1594 | dataset/clean2_removed_similar_images/54_0_0_2... | white |
| 2202 | dataset/clean2_removed_similar_images/14_1_0_2... | white |
| 4384 | dataset/clean2_removed_similar_images/23_0_0_2... | white |
| 701 | dataset/clean2_removed_similar_images/17_0_0_2... | white |
| 16274 | dataset/clean2_removed_similar_images/32_1_0_2... | white |
splits["train_y"].head()
| age | gender | |
|---|---|---|
| 1594 | 54 | male |
| 2202 | 14 | female |
| 4384 | 23 | male |
| 701 | 17 | male |
| 16274 | 32 | female |
os.makedirs("dataset/splits", exist_ok=True)
os.makedirs("dataset/splits/train", exist_ok=True)
os.makedirs("dataset/splits/val", exist_ok=True)
os.makedirs("dataset/splits/test", exist_ok=True)
if run_entire_notebook("splitting_into_folders"):
splits = pics_splits()
for dataset in ["train", "val", "test"]:
target_folder = f"dataset/splits/{dataset}/"
files = pd.DataFrame(splits[f"{dataset}_X"]["filename"])
files["target"] = files["filename"].str.replace(dir_clean_2, target_folder)
print("*" * 20, dataset, "*" * 20)
display(files.head())
for index, row in files.iterrows():
shutil.copy(row.filename, row.target)
skipping optional operation
==== 🗃️ printing cached output ====
       filename                                            target
1594   dataset/clean2_removed_similar_images/54_0_0_2...   dataset/splits/train/54_0_0_20170117190252594....
2202   dataset/clean2_removed_similar_images/14_1_0_2...   dataset/splits/train/14_1_0_20170109203638205....
4384   dataset/clean2_removed_similar_images/23_0_0_2...   dataset/splits/train/23_0_0_20170114034609023....
701    dataset/clean2_removed_similar_images/17_0_0_2...   dataset/splits/train/17_0_0_20170105183607439....
16274  dataset/clean2_removed_similar_images/32_1_0_2...   dataset/splits/train/32_1_0_20170103182408417....
******************** val ********************
       filename                                            target
17631  dataset/clean2_removed_similar_images/61_0_0_2...   dataset/splits/val/61_0_0_20170117174613406.jp...
14808  dataset/clean2_removed_similar_images/22_1_1_2...   dataset/splits/val/22_1_1_20170114033301951.jp...
2040   dataset/clean2_removed_similar_images/26_1_1_2...   dataset/splits/val/26_1_1_20170116222929223.jp...
16488  dataset/clean2_removed_similar_images/61_0_0_2...   dataset/splits/val/61_0_0_20170111222237144.jp...
8732   dataset/clean2_removed_similar_images/58_0_0_2...   dataset/splits/val/58_0_0_20170113142246036.jp...
******************** test ********************
       filename                                            target
15975  dataset/clean2_removed_similar_images/27_1_0_2...   dataset/splits/test/27_1_0_20170117120616194.j...
4956   dataset/clean2_removed_similar_images/32_0_0_2...   dataset/splits/test/32_0_0_20170117140353209.j...
11260  dataset/clean2_removed_similar_images/68_1_0_2...   dataset/splits/test/68_1_0_20170113210319664.j...
8461   dataset/clean2_removed_similar_images/42_0_0_2...   dataset/splits/test/42_0_0_20170109012239137.j...
11413  dataset/clean2_removed_similar_images/26_0_3_2...   dataset/splits/test/26_0_3_20170104230323233.j...
def verify_copy_integrity(splitname: str):
"""
checks that the number of files in the folder split
matches the number of files in the dataframe,
to make sure that the copy was correct and
no files were lost or accidentally included
"""
files_in_dir = !ls -l dataset/splits/$splitname/ | wc --lines
files_in_dir = int(files_in_dir[0]) - 1  # subtract the "total" line that ls -l prints
files_in_df_split = pics_splits()[f"{splitname}_X"].shape[0]
print(files_in_df_split, " == ", files_in_dir)
assert files_in_df_split == files_in_dir
util.check(files_in_df_split == files_in_dir)
verify_copy_integrity("train")
Loading from cache [./cached/df_dict/pics_splits.h5]
15291  ==  15291 ✅
verify_copy_integrity("val")
Loading from cache [./cached/df_dict/pics_splits.h5]
3277  ==  3277 ✅
verify_copy_integrity("test")
Loading from cache [./cached/df_dict/pics_splits.h5]
3277  ==  3277 ✅
def comparison_across_datasplits(col_name: str):
tr = splits["train_X"].join(splits["train_y"])
v = splits["val_X"].join(splits["val_y"])
tst = splits["test_X"].join(splits["test_y"])
f, ax = plt.subplots(1, 3, figsize=(15, 8))
if col_name == "age":
sns.histplot(data=tr, binwidth=5, y=col_name, ax=ax[0], color=moonstone)
sns.histplot(data=v, binwidth=5, y=col_name, ax=ax[1], color=moonstone)
sns.histplot(data=tst, binwidth=5, y=col_name, ax=ax[2], color=moonstone)
else:
order = gender_map.values() if col_name == "gender" else ethnicity_map.values()
print(order)
sns.countplot(data=tr, y=col_name, ax=ax[0], color=moonstone, order=order)
sns.countplot(data=v, y=col_name, ax=ax[1], color=moonstone, order=order)
sns.countplot(data=tst, y=col_name, ax=ax[2], color=moonstone, order=order)
plt.suptitle(f"{col_name} distribution across data splits")
ax[0].set_title("train split")
ax[1].set_title("val split")
ax[2].set_title("test split")
plt.tight_layout()
return f
@run
@cached_chart()
def split_age_comparison():
return comparison_across_datasplits("age")
Loading from cache [./cached/charts/split_age_comparison.png]
@run
@cached_chart()
def split_gender_comparison():
return comparison_across_datasplits("gender")
Loading from cache [./cached/charts/split_gender_comparison.png]
@run
@cached_chart()
def split_ethnicity_comparison():
return comparison_across_datasplits("ethnicity")
Loading from cache [./cached/charts/split_ethnicity_comparison.png]
Despite not having used hard stratification, the splits appear close enough to each other to be representative across the labels we care about (age, gender).
Now that we have ensured:
We're ready to take a look at the data. We will use the train split for EDA.
From now on, we will use the pre-split datasets/pictures from inside the dataset/splits/{split_name} folders
def load_dataset(split_name: str) -> pd.DataFrame:
path = f"dataset/splits/{split_name}"
df = pd.DataFrame(list_all_files(path), columns=["filename"])
return df.apply(parse_labels_from_filename(path + "/"), axis=1).dropna()
@cached_dataframe()
def cached_train_df():
return load_dataset("train")
@cached_dataframe()
def cached_val_df():
return load_dataset("val")
@cached_dataframe()
def cached_test_df():
return load_dataset("test")
df_train = cached_train_df()
df_val = cached_val_df()
df_test = cached_test_df()
Loading from cache [./cached/df/cached_train_df.parquet]
Loading from cache [./cached/df/cached_val_df.parquet]
Loading from cache [./cached/df/cached_test_df.parquet]
overview = df_train
overview.head()
| filename | age | gender | ethnicity | |
|---|---|---|---|---|
| 0 | dataset/splits/train/26_1_4_20170117154131789.... | 26 | female | other |
| 1 | dataset/splits/train/2_1_4_20161221203029673.j... | 2 | female | other |
| 2 | dataset/splits/train/30_1_0_20170109001620649.... | 30 | female | white |
| 3 | dataset/splits/train/25_0_0_20170120221436173.... | 25 | male | white |
| 4 | dataset/splits/train/10_0_4_20170103202338152.... | 10 | male | other |
@run
@cached_chart()
def dataset_labels():
f, ax = plt.subplots(1, 4, figsize=(15, 4))
sns.histplot(data=overview, binwidth=1, x="age", ax=ax[0], color=moonstone)
sns.histplot(data=overview, binwidth=10, x="age", ax=ax[1], color=moonstone)
sns.histplot(data=overview, x="gender", ax=ax[2], color=moonstone)
sns.histplot(data=overview, x="ethnicity", ax=ax[3], color=moonstone)
ax[0].set_title("age distribution (1yr group)")
ax[1].set_title("age distribution (decade)")
ax[2].set_title("gender distribution")
ax[3].set_title("ethnicity distribution")
plt.tight_layout()
return f
Loading from cache [./cached/charts/dataset_labels.png]
@run
@cached_chart()
def population_breakdown():
ethnicities = overview["ethnicity"].unique()
fig, axs = plt.subplots(2, len(ethnicities), figsize=(3 * len(ethnicities), 13))
for i, ethnicity in enumerate(ethnicity_map.values()):
data = overview[overview["ethnicity"] == ethnicity]
male_df = data[data["gender"] == "male"]
female_df = data[data["gender"] == "female"]
sns.histplot(
data=male_df,
binwidth=5,
y="age",
ax=axs[0, i],
color="grey",
multiple="dodge",
label="male",
)
sns.histplot(
data=female_df,
binwidth=5,
y="age",
ax=axs[0, i],
color="lightgrey",
multiple="dodge",
label="female",
)
sns.histplot(
data=male_df,
binwidth=1,
y="age",
ax=axs[1, i],
color="grey",
multiple="dodge",
label="male",
)
sns.histplot(
data=female_df,
binwidth=1,
y="age",
ax=axs[1, i],
color="lightgrey",
multiple="dodge",
label="female",
)
axs[0, i].legend()
axs[1, i].legend()
axs[0, i].set_title(f"{ethnicity} - 5 yr buckets")
axs[1, i].set_title(f"{ethnicity}")
axs[0, i].set_xlim(0, 550)
axs[1, i].set_xlim(0, 275)
plt.tight_layout()
return fig
Loading from cache [./cached/charts/population_breakdown.png]
Even though the previous "overall aggregates" seemed to present a balanced dataset, when we break it down by gender/age it's clear that each dataset has peculiarities.
A few things that jump out:
Let's build some utility methods to slice and browse our dataset, using any of the 3 dimensions we have: age, gender, ethnicity
def population_filter(
population: pd.DataFrame,
ages: list[int] = None,
genders: list[str] = None,
ethnicities: list[str] = None,
) -> pd.DataFrame:
"""
Retrieves a few samples from the population, which match the criteria specified.
"""
data = population.copy()
if ages:
if not isinstance(ages, list):
ages = [ages]
data = data[data["age"].isin(ages)]
if genders:
if not isinstance(genders, list):
genders = [genders]
data = data[data["gender"].isin(genders)]
if ethnicities:
if not isinstance(ethnicities, list):
ethnicities = [ethnicities]
data = data[data["ethnicity"].isin(ethnicities)]
return data
Now we can easily inspect slices of the dataframe and see how few samples we have for any given slice of data.
population_filter(overview, ages=[13, 11], genders="female", ethnicities="indian")
| filename | age | gender | ethnicity | |
|---|---|---|---|---|
| 875 | dataset/splits/train/13_1_3_20170109213029072.... | 13 | female | indian |
| 7066 | dataset/splits/train/11_1_3_20170104223632543.... | 11 | female | indian |
| 13577 | dataset/splits/train/13_1_3_20170117181350659.... | 13 | female | indian |
def population_sample(
population: pd.DataFrame,
ages: list[int] = None,
genders: list[str] = None,
ethnicities: list[str] = None,
) -> plt.Figure:
"""
renders pictures from the population that match the specified criteria (based on the pic labels)
"""
sample = population_filter(
population, ages=ages, genders=genders, ethnicities=ethnicities
)
sample_pics = sample.sample(min(100, len(sample)))["filename"]
cols = 10
rows = int(np.ceil(len(sample_pics) / cols))
f, ax = plt.subplots(rows, cols, figsize=(20, (2.8 * rows)))
for row, col in itertools.product(np.arange(rows), np.arange(cols)):
if (row * cols + col) < len(sample_pics):
filename = sample_pics.iloc[row * cols + col]
image = Image.open(filename)
ax[row, col].imshow(image)
title = filename.split("/")[-1:][0]
title_labels = title[:6]
ts = title[15:-13]
ax[row, col].set_title(title_labels + "..." + ts)
ax[row, col].axis("off")
plt.suptitle(f"{genders = }, {ages = }, {ethnicities = }")
plt.tight_layout()
plt.show()
return f
We can also easily visualize pics for any arbitrary subset:
@run
@cached_chart(extension="jpg")
def pop_sample_black_male():
return population_sample(overview, ages=[50, 51], genders="male", ethnicities="black")
Loading from cache [./cached/charts/pop_sample_black_male.jpg]
@run
@cached_chart(extension="jpg")
def pop_sample_white_female():
return population_sample(
overview, ages=[50, 51, 52], genders="female", ethnicities="white"
)
Loading from cache [./cached/charts/pop_sample_white_female.jpg]
⚠️ Avid reviewers will notice that some pics might be mislabeled.
For example, Gene Wilder's picture appears labeled as "white, woman, 50 years of age"... but that specific picture is a frame from the movie "Willy Wonka & the Chocolate Factory".
Even if we assumed that the person depicted is not Gene Wilder, but his character, both the gender and the age are still incorrect.
The fact that this image is misclassified on both dimensions raises concerns around the quality of the data this dataset contains.

It is well known that Willy Wonka is not a 50 year old woman, but an immortal meme, in this day and age.

@run
@cached_chart(extension="jpg")
def pop_sample_other_male_10_13():
return population_sample(
overview, ages=[10, 11, 12, 13, 14], genders="male", ethnicities="other"
)
Loading from cache [./cached/charts/pop_sample_other_male_10_13.jpg]
In Module 4 - Sprint 1, we used PyTorch Lightning to classify pictures of mushrooms. It was a challenge similar to this one, but there were some things that were not ideal.
Despite PyTorch Lightning being more lightweight than plain PyTorch, it was still quite verbose and boilerplate-heavy. I did not enjoy the experience... and I'm someone who really appreciates having knobs and dials to tune and configure everything...
For this project, I'd like to try out and get to know a different library/framework. Ideally one that is more lightweight and that allows us to easily and quickly try different structures for our NN.
FastAI seems like an ideal candidate. I have heard it's more concise, with a cleaner wrapper API that tucks away a lot of the boilerplate code/steps. I'm looking forward to this!
In terms of the approach, I expect this to be quite similar to the previous project (find a pretrained model, customize it to have the needed output layer,... ) but the trick will be to select a loss function that allows us to train that last FC layer.
The requirements talk about the need to compile and train a model that can predict in a single pass.
I consider this to be a poor idea. Normally I'd prefer to have two simpler models that can be composed/combined and, more importantly, trained separately. But I understand that this is a good requirement to make sure we can practice in-depth customization of models.
Some other approaches that could be done:
An architecture that could be constructed is this:

Most of the low-level operations (flatten, etc...) are handled and managed by the framework.
We will likely want to penalize different errors in different ways.
I thought of using an orthogonal vectorial loss function, so that the loss would be a 2-dimensional vector and could tune the error for each of the 2 outputs independently, but I was discouraged from doing so after talking to a few STLs from the course. The underlying logic was that, given enough training, a single loss that aggregates both individual losses via simple addition (+) would work just as well.
I decided to start with that (heeding the old YAGNI adage).
from fastai.callback.core import Callback
from torch.utils.tensorboard import SummaryWriter
2024-02-29 09:30:28.505076: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-29 09:30:28.996098: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-29 09:30:28.996125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-29 09:30:29.087464: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-29 09:30:29.251982: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-29 09:30:30.073575: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
from fastai.vision.all import (
Path,
TransformBlock,
DataBlock,
ImageBlock,
FuncSplitter,
get_image_files,
Resize,
aug_transforms,
Module,
resnet34,
resnet50,
resnet101,
ResNet34_Weights,
ResNet50_Weights,
ResNet101_Weights,
create_body,
create_head,
Learner,
FloatTensor,
load_learner,
)
pretrained_configs = {
"small": {
"model": resnet34,
"weights": ResNet34_Weights.IMAGENET1K_V1,
"num_features": 512,
},
"medium": {
"model": resnet50,
"weights": ResNet50_Weights.IMAGENET1K_V1,
"num_features": 2048,
},
"large": {
"model": resnet101,
"weights": ResNet101_Weights.IMAGENET1K_V1,
"num_features": 2048,
},
}
model_size = "small"
pretrained_config = pretrained_configs[model_size]
path = Path("dataset/splits")
def get_labels_age_gender_ethnicity(fname):
labels = fname.name.split("_")
# age must be a float so the
# regression calculations are correct!
age = float(labels[0])
gender = int(labels[1])
ethnicity = int(labels[2])
return age, gender, ethnicity
def get_labels_age_gender(fname):
labels = fname.name.split("_")
age = float(labels[0])
gender = int(labels[1])
return age, gender
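As a quick sanity check, the label getter above can be exercised on one of the dataset filenames (the function is re-declared here so the snippet is self-contained; the filename is one seen earlier in the notebook):

```python
from pathlib import PurePath

def get_labels_age_gender(fname):
    """Parse (age, gender) from a UTKFace-style filename."""
    labels = fname.name.split("_")
    return float(labels[0]), int(labels[1])

age, gender = get_labels_age_gender(
    PurePath("dataset/splits/train/35_0_0_20170117150935786.jpg.chip.jpg")
)
# age is a float (for the regression loss), gender an int class index
```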
class AgeGenderBlock(TransformBlock):
def __init__(
self,
type_tfms=None,
item_tfms=None,
batch_tfms=None,
dl_type=None,
dls_kwargs=None,
):
return super().__init__(
type_tfms=type_tfms,
item_tfms=item_tfms,
batch_tfms=batch_tfms,
dl_type=dl_type,
dls_kwargs=dls_kwargs,
)
def dataset_splitter(filename):
return filename.parent.name in ["val"]
dblock = DataBlock(
blocks=(ImageBlock, AgeGenderBlock()),
splitter=FuncSplitter(dataset_splitter),
get_items=get_image_files,
get_y=get_labels_age_gender,
item_tfms=Resize(460),
batch_tfms=aug_transforms(size=224, min_scale=0.75),
)
age_losses = []
gender_losses = []
total_losses = []
def weighted_loss(
age_normalization_factor: float,
gender_normalization_factor: float,
age_to_gender_weight_coef: float,
):
"""
age_normalization_factor: how to normalize the loss for age
gender_normalization_factor: how to normalize the loss for gender
age_to_gender_weight_coef: how to distribute the weight for the losses, normally 0.5
"""
def shared_loss(pred, targ) -> float:
age_pred, gender_pred = pred
age_targ, gender_targ = targ
# mse seems to be too sensitive to outliers
# age_loss = F.huber_loss(age_pred.squeeze(), age_targ.float())
age_loss = F.mse_loss(age_pred.squeeze(), age_targ.float())
gender_loss = F.cross_entropy(gender_pred, gender_targ.long())
age_loss = age_loss * age_normalization_factor
gender_loss = gender_loss * gender_normalization_factor
# standard formula could be B×MSE + (1−B)×CE, so we only need 1 param!
w = age_to_gender_weight_coef
loss = (w * age_loss) + (1 - w) * gender_loss
learn.age_loss = age_loss.item()
learn.gender_loss = gender_loss.item()
return loss
return shared_loss
class AgeGenderLossLogger(Callback):
def __init__(self, writer):
self.writer = writer
self.train_iter = 0
self.valid_iter = 0
def after_loss(self):
learn = self.learn
add_scalar = self.writer.add_scalar
if self.training:
i = self.train_iter
add_scalar("Loss/age", learn.age_loss, i)
add_scalar("Loss/gender", learn.gender_loss, i)
add_scalar("Loss/total", learn.loss, i)
self.train_iter += 1
else:
i = self.valid_iter
add_scalar("Valid_Loss/age", learn.age_loss, i)
add_scalar("Valid_Loss/gender", learn.gender_loss, i)
add_scalar("Valid_Loss/total", learn.loss, i)
self.valid_iter += 1
def after_fit(self):
self.writer.close()
dls = dblock.dataloaders(path, num_workers=4)
class AgeGenderModel(Module):
def __init__(self, encoder, n_age_classes, n_gender_classes):
self.encoder = create_body(encoder(weights=pretrained_config["weights"]))
self.age_head = create_head(pretrained_config["num_features"], n_age_classes)
self.gender_head = create_head(
pretrained_config["num_features"], n_gender_classes
)
def forward(self, x):
x = self.encoder(x)
age = self.age_head(x)
gender = self.gender_head(x)
return age, gender
It seems that the unweighted age loss is much larger than the gender loss (obviously!)
- age_loss = tensor(1830.1670 ...
- gender_loss = tensor(0.6796 ...
2024-02-06 17:41:25,186 - root - INFO - w = 0.5
2024-02-06 17:41:25,251 - root - INFO - age_loss = tensor(1830.1670, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:41:25,252 - root - INFO - gender_loss = tensor(0.6796, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:41:25,253 - root - INFO - loss = tensor(915.4233, device='cuda:0', grad_fn=<AddBackward0>)
2024-02-06 17:41:25,464 - root - INFO - w = 0.5
2024-02-06 17:41:25,528 - root - INFO - age_loss = tensor(1113.5537, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:41:25,529 - root - INFO - gender_loss = tensor(0.4847, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:41:25,529 - root - INFO - loss = tensor(557.0192, device='cuda:0', grad_fn=<AddBackward0>)
Let's try to find a weight that puts the weighted age loss on the same order of magnitude as the weighted gender loss.
It seems that with a weight of $1 \over 2000$ we achieve a good balance:
- (w * age_loss) = tensor(0.6909 ...
- (1 - w) * gender_loss = tensor(0.6667, ...
2024-02-06 17:43:59,496 - root - INFO - w = 0.0005
2024-02-06 17:43:59,563 - root - INFO - age_loss = tensor(1381.7511, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:43:59,563 - root - INFO - gender_loss = tensor(0.6670, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:43:59,564 - root - INFO - (w * age_loss) = tensor(0.6909, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,565 - root - INFO - (1 - w) * gender_loss = tensor(0.6667, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,565 - root - INFO - loss = tensor(1.3576, device='cuda:0', grad_fn=<AddBackward0>)
2024-02-06 17:43:59,779 - root - INFO - w = 0.0005
2024-02-06 17:43:59,841 - root - INFO - age_loss = tensor(1527.9042, device='cuda:0', grad_fn=<MseLossBackward0>)
2024-02-06 17:43:59,842 - root - INFO - gender_loss = tensor(0.5323, device='cuda:0', grad_fn=<NllLossBackward0>)
2024-02-06 17:43:59,843 - root - INFO - (w * age_loss) = tensor(0.7640, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,843 - root - INFO - (1 - w) * gender_loss = tensor(0.5320, device='cuda:0', grad_fn=<MulBackward0>)
2024-02-06 17:43:59,844 - root - INFO - loss = tensor(1.2960, device='cuda:0', grad_fn=<AddBackward0>)
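The balancing weight can also be derived analytically rather than by eyeballing: solving $w \cdot \text{age} = (1-w) \cdot \text{gender}$ for $w$ gives $w = \text{gender} / (\text{age} + \text{gender})$. A small sketch (`balance_weight` is an illustrative helper, using the unweighted loss magnitudes logged above):

```python
def balance_weight(age_loss: float, gender_loss: float) -> float:
    """Weight w such that w * age_loss == (1 - w) * gender_loss."""
    return gender_loss / (age_loss + gender_loss)

# magnitudes taken from the log above; w comes out close to the
# hand-picked 1/2000 = 0.0005
w = balance_weight(1381.7511, 0.6670)
```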
def new_learner():
writer = SummaryWriter()
model = AgeGenderModel(
pretrained_config["model"], n_age_classes=1, n_gender_classes=2
)
learn = Learner(
dls,
model,
loss_func=weighted_loss(
age_normalization_factor=1 / 10,
gender_normalization_factor=1 / 0.8,
age_to_gender_weight_coef=1 / 2,
),
cbs=[AgeGenderLossLogger(writer)],
)
return learn
learn = new_learner()
if run_entire_notebook("training_model"):
epochs = 30
learn = learn.to_fp16()
learn.fit_one_cycle(epochs, lr_max=0.005)
model = learn.model
timestamp_str = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
exportfilename = f"models/learner_{timestamp_str}_{model_size}_{epochs}epochs.pkl"
last_learner = f"models/learner_chosen_good.pkl"
learn.export(exportfilename, pickle_module=dill)
learn.export(
last_learner, pickle_module=dill
) # last trained model overwrite for convenience
else:
print("\n-----\nloading model from disk")
learn = load_learner("models/learner_chosen_good.pkl", pickle_module=dill, cpu=False)
print("done ✅")
skipping optional operation
==== 🗃️ printing cached output ====
epoch  train_loss  valid_loss  time
0      1.786554    1.695351    00:55
1      1.461248    1.307861    00:53
2      0.647388    0.517804    00:54
3      0.563197    0.525605    00:54
4      0.546874    0.516621    00:54
5      0.515017    0.508244    00:54
6      0.519758    0.462976    00:54
7      0.505805    0.607207    00:54
8      0.486558    0.532631    00:55
9      0.579306    0.621597    00:56
10     0.477896    0.498944    00:56
11     0.532490    0.478929    00:56
12     0.463207    0.431897    00:54
13     0.464652    0.457235    00:56
14     0.452019    0.424648    00:55
15     0.426468    0.458743    00:57
16     0.425189    0.399749    00:56
17     0.422736    0.448454    00:55
18     0.407325    0.400459    00:55
19     0.394068    0.391915    00:54
20     0.391277    0.412320    00:53
21     0.392562    0.392210    00:53
22     0.397992    0.381142    00:53
23     0.375167    0.405484    00:53
24     0.367188    0.378330    00:53
25     0.372667    0.350337    00:53
26     0.376210    0.370057    00:53
27     0.359483    0.369580    00:53
28     0.362384    0.348189    00:53
29     0.346523    0.348245    00:53
30     0.338588    0.361895    00:53
-----
loading model from disk
done ✅
Something important to note: just by enabling mixed precision training with .to_fp16(), we have been able to consistently speed up training by an impressive factor.
This represents a speedup of $132 / 82 \approx 1.6\times$, i.e. about 60% faster!
All of this, almost for free and without sacrificing accuracy: mixed precision training uses smaller data types (fp16) where it can, but falls back to full precision (fp32) where the extra precision is required.
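To see why the full-precision fallback matters, here is a small standalone illustration (plain numpy, not fastai): half precision has only about 3 decimal digits of precision, so a small weight update can be rounded away entirely if the optimizer state were kept in fp16.

```python
import numpy as np

# fp16 machine epsilon is ~1e-3, so adding a small update to a
# weight of magnitude 1.0 can round back to the original value:
w = np.float16(1.0)
update = np.float16(1e-4)
print(w + update == w)                 # True: the update vanished in fp16

w32 = np.float32(1.0)
print(w32 + np.float32(1e-4) == w32)   # False: fp32 preserves it
```

This is why mixed precision schemes keep a full-precision copy of the weights for the update step while doing the bulk of the matrix math in fp16.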
Let's check the size to see how massive it is:
!ls -hs models/learner_chosen_good.pkl
87M models/learner_chosen_good.pkl
While our model is training, let's check our system resources, settings and config to see how much we're utilizing the GPU:
(py39_lab4) edu@desk:~$ nvidia-smi
Tue Feb  6 18:11:46 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3060        Off | 00000000:01:00.0  On |                  N/A |
| 46%   67C    P2             134W / 170W |  11957MiB / 12288MiB |     96%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      1278      G   /usr/lib/xorg/Xorg                          277MiB |
|    0   N/A  N/A      1887      G   /opt/teamviewer/tv_bin/TeamViewer             2MiB |
|    0   N/A  N/A      1930      G   cinnamon                                     50MiB |
|    0   N/A  N/A     27807      G   ...GI0Y2VkNTI2Zg%3D%3D&browser=firefox       35MiB |
|    0   N/A  N/A     50469      G   /usr/lib/firefox/firefox                    162MiB |
|    0   N/A  N/A     52900      G   /usr/lib/firefox/firefox                      7MiB |
|    0   N/A  N/A     54109      C   ...anaconda3/envs/py39_lab4/bin/python    11390MiB |
+---------------------------------------------------------------------------------------+
(py39_lab4) edu@desk:~$
We can see that the configuration of model/workers/batch size utilizes almost the entirety of the GPU memory: the training process alone accounts for 11390MiB (roughly 11.1GiB of the 12GiB available).

We did not have to configure anything manually for this to work so well out of the box!
FastAI is so much nicer and more enjoyable than using PyTorch Lightning, where we had to tweak and configure low-level settings just to avoid crashing our GPU.
It's nice to see a framework that takes care of the low-level details with sane and reasonable defaults. My understanding is that, if we ever wanted to, we could drop down into lower-level code (plain PyTorch) to fully tweak minute settings (ideally similar to how seaborn offers a high-level API while still allowing low-level plt code for the minutiae).
We haven't encountered a single occasion that required us to do that 🎉 (so far)

Comparing training speed (1 epoch) depending on pretrained model size:
| Size | Model | Time per Epoch | comments |
|---|---|---|---|
| Small | Resnet34 | 1m21 | without optimizations: 2m18 |
| Medium | Resnet50 | 1m56 | |
| Large | Resnet101 | 2m54 | with mixed precision enabled, only slightly slower than the unoptimized "small" run! |
@cached_with_pickle(force=run_entire_notebook())
def learn_losses_train():
return learn.recorder.losses
@cached_with_pickle(force=run_entire_notebook())
def learn_losses_validation():
return learn.recorder.values
skipping optional operation skipping optional operation
train_losses = learn_losses_train()
valid_losses = learn_losses_validation()
Loading from cache [./cached/pickle/learn_losses_train.pickle] Loading from cache [./cached/pickle/learn_losses_validation.pickle]
Losses are not stored as part of the exported model, so we store them separately in case we need them later.
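For reference, a hypothetical minimal version of such a pickle-caching decorator; the project's actual `cached_with_pickle` (defined earlier in the notebook) may differ in details such as cache paths and logging.

```python
import os
import pickle
from functools import wraps

def cached_with_pickle(force=False, cache_dir="./cached/pickle"):
    """Cache a zero-argument function's return value as a pickle on disk.

    Hypothetical sketch: if `force` is False and a cached file exists,
    load it instead of calling the wrapped function.
    """
    def decorator(fn):
        @wraps(fn)
        def wrapper():
            path = os.path.join(cache_dir, f"{fn.__name__}.pickle")
            if not force and os.path.exists(path):
                with open(path, "rb") as f:
                    return pickle.load(f)
            result = fn()
            os.makedirs(cache_dir, exist_ok=True)
            with open(path, "wb") as f:
                pickle.dump(result, f)
            return result
        return wrapper
    return decorator
```

The `force` flag mirrors how the notebook re-computes cached values when `run_entire_notebook()` is enabled.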
@run
@cached_chart()
def chart_train_loss():
plot = sns.lineplot([float(loss) for loss in train_losses], label="train loss")
plt.title("Training Loss")
plt.xlabel("batch")
plt.ylabel("loss")
plt.ylim(0)
plt.xlim(0)
return plot
Loading from cache [./cached/charts/chart_train_loss.png]
@run
@cached_chart()
def chart_val_loss():
plot = sns.lineplot([float(loss[0]) for loss in valid_losses], label="val loss")
plt.title("Validation loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.ylim(0)
plt.xlim(0)
return plot
Loading from cache [./cached/charts/chart_val_loss.png]
It seems that the 30-epoch config is enough to get a good balance of learning vs. cost without wasting resources (the chart hints that we reach marginal gains towards the end of training, with its asymptotic behaviour).
learn.model
AgeGenderModel(
(encoder): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(age_head): Sequential(
(0): AdaptiveConcatPool2d(
(ap): AdaptiveAvgPool2d(output_size=1)
(mp): AdaptiveMaxPool2d(output_size=1)
)
(1): fastai.layers.Flatten(full=False)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.25, inplace=False)
(4): Linear(in_features=1024, out_features=512, bias=False)
(5): ReLU(inplace=True)
(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
(8): Linear(in_features=512, out_features=1, bias=False)
)
(gender_head): Sequential(
(0): AdaptiveConcatPool2d(
(ap): AdaptiveAvgPool2d(output_size=1)
(mp): AdaptiveMaxPool2d(output_size=1)
)
(1): fastai.layers.Flatten(full=False)
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.25, inplace=False)
(4): Linear(in_features=1024, out_features=512, bias=False)
(5): ReLU(inplace=True)
(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5, inplace=False)
(8): Linear(in_features=512, out_features=2, bias=False)
)
)
(py39_lab4) edu@desk:~/turing/projects/sprint15-profiling/project$ tensorboard --logdir runs/
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6006/ (Press CTRL+C to quit)
Tensorboard has been a very useful tool to inspect and debug the model's learning (as well as to detect a bug in my code that resulted in data leakage).


Context:
The researchers found that overall, the volunteers were incorrect by an average of eight years in their estimates.
Outcome:
Observation:
Outcome:
def performance_for_split(splitname: str):
"""
assess inference performance for the given data split.
Performs inference and collect a summary dataframe for each of
the inputs, with performance metrics
:param splitname: name of the subdir under dataset/splits/{splitname}
"""
tolerance_years = 8
items = get_image_files(path / splitname)
print(len(items))
actuals = [get_labels_age_gender_ethnicity(item) for item in items]
dl = dls.test_dl(items)
preds, _ = learn.get_preds(dl=dl)
display(preds)
performance = pd.DataFrame(
{
"actual_age": [t[0] for t in actuals],
"pred_age": preds[0][:, 0],
"actual_gender": [t[1] for t in actuals],
"pred_gender": torch.argmax(preds[1], dim=1),
"filename": [str(p) for p in items],
"ethnicity": [t[2] for t in actuals],
}
)
performance["age_error"] = performance.pred_age - performance.actual_age
performance["gender_error"] = performance.pred_gender - performance.actual_gender
performance["age_correct"] = performance["age_error"].between(
-tolerance_years, tolerance_years
)
performance["gender_correct"] = performance["gender_error"] == 0
performance.attrs["splitname"] = splitname
return performance
def plot_age_errors(performance_df: pd.DataFrame, ax=None):
if ax is None:
ax = plt.gca()
title = f"age errors - {performance_df.attrs['splitname']} split"
ax.set_title(title)
# ax.set_xlim(-40, 40)
return sns.histplot(
data=performance_df,
x="age_error",
hue="age_correct",
palette=["red", "green"],
# bins=np.arange(-40, 40, 1),
ax=ax,
)
def plot_age_errors_color(
performance_df: pd.DataFrame,
ax=None,
hue="age_correct",
palette=["red", "green"],
):
if ax is None:
ax = plt.gca()
title = f"age errors - {performance_df.attrs['splitname']} split"
ax.set_title(title)
# ax.set_xlim(-40, 40)
return sns.histplot(
data=performance_df,
x="age_error",
hue=hue,
multiple="stack",
palette=palette,
# bins=np.arange(-40, 40, 1),
ax=ax,
)
def plot_age_scatter(performance_df: pd.DataFrame, ax=None, alpha=0.05):
if ax is None:
ax = plt.gca()
title = f"age predictions - {performance_df.attrs['splitname']} split"
ax.set_title(title)
ax.axline((0, 0), (100, 100), color="green")
return sns.scatterplot(
data=performance_df,
y="actual_age",
x="pred_age",
alpha=alpha,
hue="age_correct",
        palette={False: "red", True: "green"},  # dict avoids palette-length warnings on single-class subsets
ax=ax,
)
def plot_gender_errors(performance_df: pd.DataFrame, ax=None):
if ax is None:
ax = plt.gca()
title = f"gender errors - {performance_df.attrs['splitname']} split"
ax.set_title(title)
return sns.countplot(
data=performance_df,
x="gender_error",
hue="gender_correct",
dodge=False,
palette={True: "green", False: "red"},
ax=ax,
)
performance_train = performance_for_split("train")
performance_train
15291
(tensor([[24.3281],
[ 3.6504],
[29.9688],
...,
[43.6875],
[29.4844],
[23.0156]]),
tensor([[-1.8779, 1.8291],
[ 0.0383, -0.0121],
[-1.8516, 1.8262],
...,
[ 4.3555, -4.6094],
[-1.8867, 1.8916],
[-3.0371, 2.9609]]))
| actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 26.0 | 24.328125 | 1 | 1 | dataset/splits/train/26_1_4_20170117154131789.... | 4 | -1.671875 | 0 | True | True |
| 1 | 2.0 | 3.650391 | 1 | 0 | dataset/splits/train/2_1_4_20161221203029673.j... | 4 | 1.650391 | -1 | True | False |
| 2 | 30.0 | 29.968750 | 1 | 1 | dataset/splits/train/30_1_0_20170109001620649.... | 0 | -0.031250 | 0 | True | True |
| 3 | 25.0 | 23.843750 | 0 | 0 | dataset/splits/train/25_0_0_20170120221436173.... | 0 | -1.156250 | 0 | True | True |
| 4 | 10.0 | 11.773438 | 0 | 0 | dataset/splits/train/10_0_4_20170103202338152.... | 4 | 1.773438 | 0 | True | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 15286 | 55.0 | 49.843750 | 0 | 0 | dataset/splits/train/55_0_0_20170104184424541.... | 0 | -5.156250 | 0 | True | True |
| 15287 | 43.0 | 41.812500 | 0 | 0 | dataset/splits/train/43_0_3_20170119181404861.... | 3 | -1.187500 | 0 | True | True |
| 15288 | 35.0 | 43.687500 | 0 | 0 | dataset/splits/train/35_0_0_20170104183852983.... | 0 | 8.687500 | 0 | False | True |
| 15289 | 35.0 | 29.484375 | 1 | 1 | dataset/splits/train/35_1_0_20170117144916091.... | 0 | -5.515625 | 0 | True | True |
| 15290 | 18.0 | 23.015625 | 1 | 1 | dataset/splits/train/18_1_0_20170117140343665.... | 0 | 5.015625 | 0 | True | True |
15291 rows × 10 columns
performance_val = performance_for_split("val")
performance_val
3277
(tensor([[67.3125],
[35.7500],
[42.6250],
...,
[30.5469],
[23.5000],
[39.2812]]),
tensor([[ 3.0098, -3.1895],
[-2.5020, 2.4551],
[ 3.0449, -3.2480],
...,
[-1.8887, 1.8955],
[ 2.1738, -2.3848],
[ 4.8281, -5.1016]]))
| actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 61.0 | 67.312500 | 0 | 0 | dataset/splits/val/61_0_3_20170119211856632.jp... | 3 | 6.312500 | 0 | True | True |
| 1 | 26.0 | 35.750000 | 1 | 1 | dataset/splits/val/26_1_0_20170117153717556.jp... | 0 | 9.750000 | 0 | False | True |
| 2 | 26.0 | 42.625000 | 0 | 0 | dataset/splits/val/26_0_0_20170117120944631.jp... | 0 | 16.625000 | 0 | False | True |
| 3 | 46.0 | 55.093750 | 1 | 1 | dataset/splits/val/46_1_0_20170104184041597.jp... | 0 | 9.093750 | 0 | False | True |
| 4 | 37.0 | 37.875000 | 0 | 0 | dataset/splits/val/37_0_0_20170119180034627.jp... | 0 | 0.875000 | 0 | True | True |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3272 | 72.0 | 65.187500 | 0 | 0 | dataset/splits/val/72_0_2_20170105174444334.jp... | 2 | -6.812500 | 0 | True | True |
| 3273 | 35.0 | 40.500000 | 0 | 0 | dataset/splits/val/35_0_0_20170117182852603.jp... | 0 | 5.500000 | 0 | True | True |
| 3274 | 29.0 | 30.546875 | 1 | 1 | dataset/splits/val/29_1_1_20170112204807283.jp... | 1 | 1.546875 | 0 | True | True |
| 3275 | 13.0 | 23.500000 | 0 | 0 | dataset/splits/val/13_0_3_20170110232628896.jp... | 3 | 10.500000 | 0 | False | True |
| 3276 | 32.0 | 39.281250 | 0 | 0 | dataset/splits/val/32_0_0_20170117203115358.jp... | 0 | 7.281250 | 0 | True | True |
3277 rows × 10 columns
performance_test = performance_for_split("test")
performance_test
3277
(tensor([[54.1562],
[46.2188],
[36.7500],
...,
[42.4375],
[12.6719],
[76.3750]]),
tensor([[ 4.3945, -4.6641],
[ 2.9160, -3.0664],
[-1.8604, 1.8027],
...,
[ 3.3223, -3.5234],
[-1.2158, 1.2852],
[ 0.8242, -0.8218]]))
| actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44.0 | 54.156250 | 0 | 0 | dataset/splits/test/44_0_3_20170119204704727.j... | 3 | 10.156250 | 0 | False | True |
| 1 | 48.0 | 46.218750 | 0 | 0 | dataset/splits/test/48_0_0_20170109012109036.j... | 0 | -1.781250 | 0 | True | True |
| 2 | 49.0 | 36.750000 | 1 | 1 | dataset/splits/test/49_1_1_20170113000544753.j... | 1 | -12.250000 | 0 | False | True |
| 3 | 61.0 | 73.500000 | 1 | 1 | dataset/splits/test/61_1_0_20170120225333848.j... | 0 | 12.500000 | 0 | False | True |
| 4 | 1.0 | 2.515625 | 0 | 1 | dataset/splits/test/1_0_0_20161219204552941.jp... | 0 | 1.515625 | 1 | True | False |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3272 | 23.0 | 25.703125 | 1 | 1 | dataset/splits/test/23_1_0_20170117145019683.j... | 0 | 2.703125 | 0 | True | True |
| 3273 | 12.0 | 21.859375 | 0 | 1 | dataset/splits/test/12_0_0_20170117165940524.j... | 0 | 9.859375 | 1 | False | False |
| 3274 | 46.0 | 42.437500 | 0 | 0 | dataset/splits/test/46_0_0_20170104203049435.j... | 0 | -3.562500 | 0 | True | True |
| 3275 | 29.0 | 12.671875 | 1 | 1 | dataset/splits/test/29_1_1_20170114024736192.j... | 1 | -16.328125 | 0 | False | True |
| 3276 | 75.0 | 76.375000 | 0 | 0 | dataset/splits/test/75_0_3_20170111210912724.j... | 3 | 1.375000 | 0 | True | True |
3277 rows × 10 columns
@run
@cached_chart(force=run_entire_notebook())
def model_performance_all_splits():
f, ax = plt.subplots(3, 3, figsize=(15, 12))
plot_age_errors(performance_train, ax=ax[0, 0])
plot_age_errors(performance_val, ax=ax[1, 0])
plot_age_errors(performance_test, ax=ax[2, 0])
ax[0, 0].set_ylim(0, 2500)
ax[1, 0].set_ylim(0, 500)
ax[2, 0].set_ylim(0, 500)
plot_age_scatter(performance_train, ax=ax[0, 1], alpha=0.03)
plot_age_scatter(performance_val, ax=ax[1, 1], alpha=0.1)
plot_age_scatter(performance_test, ax=ax[2, 1], alpha=0.1)
plot_gender_errors(performance_train, ax=ax[0, 2])
plot_gender_errors(performance_val, ax=ax[1, 2])
plot_gender_errors(performance_test, ax=ax[2, 2])
plt.tight_layout()
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_performance_all_splits.png]
Let's zoom into the age predictions scatterplots
@run
@cached_chart(force=run_entire_notebook())
def model_performance_scatter_age():
f, ax = plt.subplots(1, 2, figsize=(15, 8))
plot_age_scatter(performance_train, ax=ax[0], alpha=0.03)
plot_age_scatter(performance_test, ax=ax[1], alpha=0.1)
plt.tight_layout()
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_performance_scatter_age.png]
Notes on reading these charts:
Some thoughts:
If we use the +/- 8 years range to classify age prediction performance, we get quite notable results.
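As a toy illustration of how that tolerance turns the regression into a binary correct/incorrect call (mirroring the `between` check in `performance_for_split` above; the numbers here are made up):

```python
import pandas as pd

# Made-up predictions: the ±8-year band decides correctness
df = pd.DataFrame({"actual_age": [26, 13, 70],
                   "pred_age": [24.3, 23.5, 30.9]})
df["age_error"] = df["pred_age"] - df["actual_age"]
df["age_correct"] = df["age_error"].between(-8, 8)
print(df["age_correct"].value_counts(normalize=True) * 100)
```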
performance_train["age_correct"].value_counts(normalize=True) * 100
age_correct True 88.692695 False 11.307305 Name: proportion, dtype: float64
performance_val["age_correct"].value_counts(normalize=True) * 100
age_correct True 80.958193 False 19.041807 Name: proportion, dtype: float64
performance_test["age_correct"].value_counts(normalize=True) * 100
age_correct True 87.610619 False 12.389381 Name: proportion, dtype: float64
@run
@cached_chart(force=run_entire_notebook())
def model_confusion_matrix_gender():
ConfusionMatrixDisplay.from_predictions(
performance_test["actual_gender"],
performance_test["pred_gender"],
normalize="true",
cmap="Greys_r",
)
plt.grid(False)
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_confusion_matrix_gender.png]
This is not foolproof because the boundaries of the clusters are arbitrary (and a difference of 0.1 year could be marked as "error" if it happens to fall between 29.96 and 30.05 yrs, for example)... but it's going to be good enough to get a rough idea.
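For example, with the floor-division binning used below, two ages barely 0.1 years apart land in different decade bins:

```python
# Ages straddling the 30-year boundary fall into different decades
ages = [29.96, 30.05]
print([age // 10 for age in ages])  # [2.0, 3.0]
```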
performance_test["actual_decade"] = performance_test["actual_age"] // 10
performance_test["pred_decade"] = performance_test["pred_age"] // 10
@run
@cached_chart(force=run_entire_notebook())
def model_confusion_matrix_age_decades():
plt.figure(figsize=(12, 12))
ConfusionMatrixDisplay.from_predictions(
performance_test["actual_decade"],
performance_test["pred_decade"],
normalize="true",
values_format=",.2f",
cmap="Greys_r",
ax=plt.gca(),
)
plt.grid(False)
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_confusion_matrix_age_decades.png]
Observations:
We prepared and tested different sizes of pretrained models (resnet34, 50, 101), but in the end the smallest one gave us fairly good performance, so there was little need to train a more complex one.

In this section we would like to explore a few things:
Let's bring back the original chart showing the split based on each of the attributes
@run
@cached_chart(force=False)
def dataset_labels():
raise NotImplementedError(
"Cached chart not found in cache/ dir.\n"
"Consider re-running the notebook from end to end, "
"or download the entire repository"
)
Loading from cache [./cached/charts/dataset_labels.png]
Recall that, despite some of these appearing balanced at first glance (gender, for example), once we start digging deeper we see clear imbalances across multiple dimensions (the white ethnicity label has mostly male pictures, while the asian label has almost twice as many female pictures as male).
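A normalized crosstab is a quick way to quantify that kind of cross-dimensional imbalance. The data below is a toy example exaggerating the effect, not the real UTKFace labels:

```python
import pandas as pd

# Toy labels: one group skewed male, the other skewed female
df = pd.DataFrame({
    "ethnicity": ["white"] * 5 + ["asian"] * 3,
    "gender": ["male", "male", "male", "female", "male",
               "female", "female", "male"],
})
# normalize="index" gives the gender share within each ethnicity row
print(pd.crosstab(df["ethnicity"], df["gender"], normalize="index"))
```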
@run
@cached_chart(force=False)
def population_breakdown():
raise NotImplementedError(
"Cached chart not found in cache/ dir.\n"
"Consider re-running the notebook from end to end, "
"or download the entire repository"
)
Loading from cache [./cached/charts/population_breakdown.png]
performance_test
| actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | actual_decade | pred_decade | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44.0 | 54.156250 | 0 | 0 | dataset/splits/test/44_0_3_20170119204704727.j... | 3 | 10.156250 | 0 | False | True | 4.0 | 5.0 |
| 1 | 48.0 | 46.218750 | 0 | 0 | dataset/splits/test/48_0_0_20170109012109036.j... | 0 | -1.781250 | 0 | True | True | 4.0 | 4.0 |
| 2 | 49.0 | 36.750000 | 1 | 1 | dataset/splits/test/49_1_1_20170113000544753.j... | 1 | -12.250000 | 0 | False | True | 4.0 | 3.0 |
| 3 | 61.0 | 73.500000 | 1 | 1 | dataset/splits/test/61_1_0_20170120225333848.j... | 0 | 12.500000 | 0 | False | True | 6.0 | 7.0 |
| 4 | 1.0 | 2.515625 | 0 | 1 | dataset/splits/test/1_0_0_20161219204552941.jp... | 0 | 1.515625 | 1 | True | False | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 3272 | 23.0 | 25.703125 | 1 | 1 | dataset/splits/test/23_1_0_20170117145019683.j... | 0 | 2.703125 | 0 | True | True | 2.0 | 2.0 |
| 3273 | 12.0 | 21.859375 | 0 | 1 | dataset/splits/test/12_0_0_20170117165940524.j... | 0 | 9.859375 | 1 | False | False | 1.0 | 2.0 |
| 3274 | 46.0 | 42.437500 | 0 | 0 | dataset/splits/test/46_0_0_20170104203049435.j... | 0 | -3.562500 | 0 | True | True | 4.0 | 4.0 |
| 3275 | 29.0 | 12.671875 | 1 | 1 | dataset/splits/test/29_1_1_20170114024736192.j... | 1 | -16.328125 | 0 | False | True | 2.0 | 1.0 |
| 3276 | 75.0 | 76.375000 | 0 | 0 | dataset/splits/test/75_0_3_20170111210912724.j... | 3 | 1.375000 | 0 | True | True | 7.0 | 7.0 |
3277 rows × 12 columns
@run
@cached_chart(force=run_entire_notebook())
def model_performance_age_scatter_by_ethnicity_and_gender():
f, ax = plt.subplots(2, 5, figsize=(20, 10))
for col, ethnicity in enumerate(ethnicity_map):
e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
for row, gender in enumerate(gender_map):
g_subset = e_subset[e_subset["actual_gender"] == gender]
plot_age_scatter(g_subset, ax=ax[row, col], alpha=0.1)
ax[row, col].set_title(
f"age predictions for {ethnicity_map[ethnicity]}, {gender_map[gender]}"
)
plt.tight_layout()
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_performance_age_scatter_by_ethnicity_and_gender.png]
Observations:
@run
@cached_chart(force=run_entire_notebook())
def model_performance_age_errors_by_ethnicity_and_gender():
f, ax = plt.subplots(2, 5, figsize=(18, 8))
for col, ethnicity in enumerate(ethnicity_map):
e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
for row, gender in enumerate(gender_map):
g_subset = e_subset[e_subset["actual_gender"] == gender]
plot_age_errors(g_subset, ax=ax[row, col])
ax[row, col].set_title(
f"age errors for {ethnicity_map[ethnicity]}, {gender_map[gender]}"
)
# ax[row, col].set_ylim(0, 10)
plt.tight_layout()
return plt.gcf()
skipping optional operation Loading from cache [./cached/charts/model_performance_age_errors_by_ethnicity_and_gender.png]
Observations:
@run
@cached_chart(force=run_entire_notebook(value_only=True))
def model_performance_gender_by_ethnicity():
f, ax = plt.subplots(2, 5, figsize=(20, 6))
for col, ethnicity in enumerate(ethnicity_map):
e_subset = performance_test[performance_test["ethnicity"] == ethnicity]
ConfusionMatrixDisplay.from_predictions(
e_subset["actual_gender"],
e_subset["pred_gender"],
display_labels=gender_map.values(),
cmap="Greys_r",
ax=ax[0, col],
)
ConfusionMatrixDisplay.from_predictions(
e_subset["actual_gender"],
e_subset["pred_gender"],
normalize="true",
values_format=",.2%",
display_labels=gender_map.values(),
cmap="Greys_r",
ax=ax[1, col],
)
ax[0, col].grid(False)
ax[1, col].grid(False)
ax[0, col].set_title(f"gender predictions for {ethnicity_map[ethnicity]}")
ax[1, col].set_title(
f"gender predictions for {ethnicity_map[ethnicity]} - normalized"
)
plt.grid(False)
plt.tight_layout()
return plt.gcf()
Loading from cache [./cached/charts/model_performance_gender_by_ethnicity.png]
top_x_age_errors = 100
top_age_errors = pd.concat(
[
performance_test.sort_values(by="age_error")[:top_x_age_errors],
performance_test.sort_values(by="age_error")[-top_x_age_errors:],
]
)
top_age_errors.head(20)
| actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | actual_decade | pred_decade | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2545 | 70.0 | 30.875000 | 0 | 0 | dataset/splits/test/70_0_0_20170111200757701.j... | 0 | -39.125000 | 0 | False | True | 7.0 | 3.0 |
| 1929 | 116.0 | 81.500000 | 1 | 1 | dataset/splits/test/116_1_0_20170120134921760.... | 0 | -34.500000 | 0 | False | True | 11.0 | 8.0 |
| 1023 | 77.0 | 47.656250 | 0 | 0 | dataset/splits/test/77_0_1_20170116210256280.j... | 1 | -29.343750 | 0 | False | True | 7.0 | 4.0 |
| 1964 | 57.0 | 31.000000 | 0 | 0 | dataset/splits/test/57_0_0_20170117172532619.j... | 0 | -26.000000 | 0 | False | True | 5.0 | 3.0 |
| 1141 | 116.0 | 92.937500 | 1 | 1 | dataset/splits/test/116_1_2_20170112220255503.... | 2 | -23.062500 | 0 | False | True | 11.0 | 9.0 |
| 3056 | 86.0 | 62.968750 | 1 | 1 | dataset/splits/test/86_1_0_20170120225525242.j... | 0 | -23.031250 | 0 | False | True | 8.0 | 6.0 |
| 1956 | 90.0 | 68.812500 | 0 | 0 | dataset/splits/test/90_0_0_20170120230038954.j... | 0 | -21.187500 | 0 | False | True | 9.0 | 6.0 |
| 1157 | 54.0 | 33.343750 | 1 | 1 | dataset/splits/test/54_1_0_20170117171505517.j... | 0 | -20.656250 | 0 | False | True | 5.0 | 3.0 |
| 1974 | 54.0 | 33.406250 | 0 | 1 | dataset/splits/test/54_0_0_20170113210127075.j... | 0 | -20.593750 | 1 | False | False | 5.0 | 3.0 |
| 2737 | 80.0 | 59.656250 | 0 | 0 | dataset/splits/test/80_0_0_20170117173234032.j... | 0 | -20.343750 | 0 | False | True | 8.0 | 5.0 |
| 912 | 65.0 | 45.968750 | 0 | 0 | dataset/splits/test/65_0_1_20170113145609182.j... | 1 | -19.031250 | 0 | False | True | 6.0 | 4.0 |
| 2685 | 61.0 | 42.000000 | 1 | 1 | dataset/splits/test/61_1_0_20170117174551886.j... | 0 | -19.000000 | 0 | False | True | 6.0 | 4.0 |
| 2302 | 67.0 | 48.218750 | 0 | 0 | dataset/splits/test/67_0_0_20170113210319928.j... | 0 | -18.781250 | 0 | False | True | 6.0 | 4.0 |
| 2404 | 49.0 | 30.468750 | 0 | 0 | dataset/splits/test/49_0_3_20170119205458583.j... | 3 | -18.531250 | 0 | False | True | 4.0 | 3.0 |
| 301 | 50.0 | 31.515625 | 1 | 1 | dataset/splits/test/50_1_0_20170105162633419.j... | 0 | -18.484375 | 0 | False | True | 5.0 | 3.0 |
| 2486 | 58.0 | 39.781250 | 1 | 1 | dataset/splits/test/58_1_1_20170113012224304.j... | 1 | -18.218750 | 0 | False | True | 5.0 | 3.0 |
| 1543 | 70.0 | 51.875000 | 1 | 1 | dataset/splits/test/70_1_0_20170120134300289.j... | 0 | -18.125000 | 0 | False | True | 7.0 | 5.0 |
| 2386 | 89.0 | 71.062500 | 0 | 0 | dataset/splits/test/89_0_1_20170117182437361.j... | 1 | -17.937500 | 0 | False | True | 8.0 | 7.0 |
| 1298 | 51.0 | 33.281250 | 0 | 0 | dataset/splits/test/51_0_1_20170113142040362.j... | 1 | -17.718750 | 0 | False | True | 5.0 | 3.0 |
| 864 | 70.0 | 52.750000 | 1 | 1 | dataset/splits/test/70_1_0_20170117163559185.j... | 0 | -17.250000 | 0 | False | True | 7.0 | 5.0 |
plot_age_errors_color(
top_age_errors.sort_values(by="ethnicity", ascending=False),
hue="ethnicity",
palette="viridis",
)
<AxesSubplot: title={'center': 'age errors - test split'}, xlabel='age_error', ylabel='Count'>
No major differences in the predictions based on the ethnicity label.
fig, ax = plt.subplots(1, 5, figsize=(16, 4))
for i, (e, ethnicity) in enumerate(ethnicity_map.items()):
cax = ax[i]
top_age_errors_ethnic = top_age_errors[top_age_errors["ethnicity"] == e]
plot_age_scatter(top_age_errors_ethnic, ax=cax, alpha=0.5)
plt.tight_layout()
plt.show()
/tmp/ipykernel_15851/1151474415.py:49: UserWarning: The palette list has more values (2) than needed (1), which may not be intended.
  return sns.scatterplot(
Nothing new compared to what we already saw earlier.
Since LIME is model-agnostic, we need to bridge the gap between the library and our custom model.
This means we must feed LIME data transformed and pre-processed the same way our model does during training and prediction.
Recall that our preprocessing pipeline uses no data augmentation (for now) and is quite simple: resizing and cropping.
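LIME's image explainer only needs a function that maps a batch of raw images to model outputs, so the bridge amounts to wrapping our preprocessing around the model call. The sketch below is a dependency-free illustration of that idea, not the project's actual `lime_explain_pics` helper: the nearest-neighbour resize and the `model` callable interface are assumptions.

```python
import numpy as np

def preprocess(img: np.ndarray, size: int = 224) -> np.ndarray:
    """Center-crop to a square, then nearest-neighbour resize to (size, size).
    Mirrors a simple resize+crop pipeline with no augmentation."""
    h, w = img.shape[:2]
    side = min(h, w)
    top, left = (h - side) // 2, (w - side) // 2
    crop = img[top:top + side, left:left + side]
    # nearest-neighbour resize via index mapping (keeps this sketch dependency-free)
    idx = np.arange(size) * side // size
    return crop[idx][:, idx]

def make_predict_fn(model):
    """LIME's explainer expects fn(batch_of_images) -> per-class scores,
    so we apply the same preprocessing to every perturbed image first."""
    def predict_fn(images: np.ndarray) -> np.ndarray:
        batch = np.stack([preprocess(im) for im in images])
        return model(batch)  # model is assumed to return an (n, n_classes) array
    return predict_fn
```

A wrapper like this is what would be passed as `classifier_fn` to LIME's image explainer, so that every perturbed sample it generates goes through the same resize/crop the model saw at train time.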
explain_files = glob.glob("dataset/explain/*.jpg")
Let's try the images that got the worst errors.
age_underestimate = top_age_errors[:5]
age_overestimate = top_age_errors[-5:]
age_underestimate
| | actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | actual_decade | pred_decade |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2545 | 70.0 | 30.87500 | 0 | 0 | dataset/splits/test/70_0_0_20170111200757701.j... | 0 | -39.12500 | 0 | False | True | 7.0 | 3.0 |
| 1929 | 116.0 | 81.50000 | 1 | 1 | dataset/splits/test/116_1_0_20170120134921760.... | 0 | -34.50000 | 0 | False | True | 11.0 | 8.0 |
| 1023 | 77.0 | 47.65625 | 0 | 0 | dataset/splits/test/77_0_1_20170116210256280.j... | 1 | -29.34375 | 0 | False | True | 7.0 | 4.0 |
| 1964 | 57.0 | 31.00000 | 0 | 0 | dataset/splits/test/57_0_0_20170117172532619.j... | 0 | -26.00000 | 0 | False | True | 5.0 | 3.0 |
| 1141 | 116.0 | 92.93750 | 1 | 1 | dataset/splits/test/116_1_2_20170112220255503.... | 2 | -23.06250 | 0 | False | True | 11.0 | 9.0 |
age_overestimate
| | actual_age | pred_age | actual_gender | pred_gender | filename | ethnicity | age_error | gender_error | age_correct | gender_correct | actual_decade | pred_decade |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 575 | 27.0 | 44.437500 | 0 | 0 | dataset/splits/test/27_0_2_20170119193329569.j... | 2 | 17.437500 | 0 | False | True | 2.0 | 4.0 |
| 565 | 40.0 | 62.312500 | 0 | 0 | dataset/splits/test/40_0_0_20170113210319647.j... | 0 | 22.312500 | 0 | False | True | 4.0 | 6.0 |
| 397 | 36.0 | 58.531250 | 0 | 0 | dataset/splits/test/36_0_0_20170113210318892.j... | 0 | 22.531250 | 0 | False | True | 3.0 | 5.0 |
| 776 | 1.0 | 25.328125 | 0 | 0 | dataset/splits/test/1_0_3_20170104230640081.jp... | 3 | 24.328125 | 0 | False | True | 0.0 | 2.0 |
| 1237 | 1.0 | 28.125000 | 0 | 1 | dataset/splits/test/1_0_4_20161221193016140.jp... | 4 | 27.125000 | 1 | False | False | 0.0 | 2.0 |
@run
@cached_chart()
def explain_image_underestimate_0_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_underestimate.iloc[0].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_underestimate_1_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_underestimate.iloc[1].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_underestimate_2_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_underestimate.iloc[2].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_underestimate_3_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_underestimate.iloc[3].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.55)'}>
@run
@cached_chart()
def explain_image_underestimate_4_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_underestimate.iloc[4].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
Observations:
Outcomes:
@run
@cached_chart()
def explain_image_overestimate_0_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_overestimate.iloc[0].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_overestimate_1_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_overestimate.iloc[1].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_overestimate_2_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_overestimate.iloc[2].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 5 to 6 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_overestimate_3_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_overestimate.iloc[3].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
@run
@cached_chart()
def explain_image_overestimate_4_age():
return lime_explain_pics.lime_mask_explain(
learn,
age_overestimate.iloc[4].filename,
"age",
num_samples=100,
)
0%| | 0/100 [00:00<?, ?it/s]
<AxesSubplot: title={'center': 'age 14 to 15 (conf = 0.00)'}>
Observations:
Outcomes:
In this project we created a custom fastai model that uses a pretrained neural network to perform compound (multi-task) predictions using computer vision.
The pretrained model is based on ResNet34 and uses the IMAGENET1K_V1 weights.
We chose the smallest version of this pretrained family because it performed well enough; if more performance were required, we would expect larger, more complex pretrained models to perform slightly better.
For age prediction, the model exceeds average human performance (±8 years) in the vast majority of cases: over 80% of predictions fall within this ±8-year range.
For gender prediction, the model is almost perfect, achieving over 98% correct matches across all categories (combinations of gender and ethnicity).
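The per-category accuracy claim can be checked by grouping the performance table by gender and ethnicity. Here is a minimal sketch in plain Python (the column names are assumed to match the tables above; the notebook itself would more naturally use a pandas groupby on `performance_test`):

```python
from collections import defaultdict

def accuracy_by_group(rows, group_keys=("actual_gender", "ethnicity")):
    """rows: iterable of dicts with the performance-table columns.
    Returns {group: fraction of rows where pred_gender == actual_gender}."""
    hits, totals = defaultdict(int), defaultdict(int)
    for r in rows:
        key = tuple(r[k] for k in group_keys)
        totals[key] += 1
        hits[key] += int(r["pred_gender"] == r["actual_gender"])
    return {k: hits[k] / totals[k] for k in totals}
```

Asserting that every group's accuracy clears the 98% bar (rather than only the overall average) is what guards against a model that is accurate in aggregate but poor on a minority slice.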
About gender vs. gender expression.
This dataset required using computer vision to assess and predict "gender" from a collection of faces. Since the faces are the only data we have about these people, the only thing we could actually be assessing is their gender expression. These differences in terminology matter: they enable a more nuanced conversation about what this dataset actually contains.
Using the Genderbread Person diagram, we can see that we are only able to assess and predict the yellow dials related to gender expression, and nothing beyond that.

We see that the model performs better for the majority classes/clusters.
We see that while the other classes also have errors, the data remaining after slicing by gender, age bucket, and ethnicity is small; we end up with similar results and no major problematic areas.
We could have used more stringent criteria for "correct" age guesses, allowing a smaller error for younger people and a larger one for older groups, but the predictions already fall within those tolerances, so we use simpler criteria to get the point across easily. We don't expect the performance metrics would change significantly had we used a more triangular range (instead of two parallel bars at ±8 years).
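Both variants of the "correct" criterion can be expressed as a single tolerance function; this is an illustrative sketch, and `rel_tol` is a hypothetical knob (not a parameter from the notebook) that widens the band with age:

```python
def age_correct(actual: float, predicted: float,
                base_tol: float = 8.0, rel_tol: float = 0.0) -> bool:
    """Flag a prediction as 'correct' if the error falls inside the tolerance band.
    rel_tol=0 gives the flat ±8-year band used above; a positive rel_tol
    produces the 'triangular' variant that is stricter for the young."""
    tolerance = base_tol + rel_tol * actual
    return abs(predicted - actual) <= tolerance
```

With `rel_tol=0.1`, for example, a 70-year-old gets a ±15-year band while a 10-year-old gets ±9 years, which matches the intuition that a 5-year error matters far more for a toddler than for a septuagenarian.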
This dataset proved to have abundant duplication, which required extensive and intensive cleaning to remove duplicated, partially duplicated, and "similar-enough" images.
This proved tedious and expensive. The solution we arrived at can run 200 million comparisons in under a minute, but it took a substantial amount of time to plan, prepare, and assess the various performance/time tradeoffs.
Getting this right was crucial: otherwise the dataset would suffer from extreme data leakage, and the model would perform well even with little training, simply through memorization.
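The dedup code itself is not shown here, but one common way to make hundreds of millions of pairwise image comparisons tractable is to reduce each image to a short perceptual hash and compare hashes with a vectorized Hamming distance. The sketch below uses a dHash-style difference hash as an illustration; it is not necessarily the method used in this project:

```python
import numpy as np

def dhash(img: np.ndarray, size: int = 8) -> np.ndarray:
    """Difference hash: downsample to (size, size+1) grayscale, then compare
    horizontally adjacent pixels -> size*size bits per image."""
    gray = img.mean(axis=2) if img.ndim == 3 else img
    h, w = gray.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size + 1) * w // (size + 1)
    small = gray[np.ix_(rows, cols)]
    return (small[:, 1:] > small[:, :-1]).flatten()

def hamming_matrix(hashes: np.ndarray) -> np.ndarray:
    """Pairwise Hamming distances between n boolean hash vectors, vectorized:
    XOR via broadcasting, then sum over the bit axis."""
    a = hashes.astype(np.uint8)
    return (a[:, None, :] ^ a[None, :, :]).sum(axis=2)
```

Because each hash is only 64 bits, the all-pairs comparison reduces to cheap integer operations, which is how this class of approach reaches hundreds of millions of comparisons per minute; pairs below a small distance threshold are then flagged as "similar enough" for manual review or removal.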
💛💛💛💛💛💛💛💛💛
🤍🤍🤍🤍🤍🤍🤍🤍🤍
💜💜💜💜💜💜💜💜💜
🖤🖤🖤🖤🖤🖤🖤🖤🖤
In this TC project, we were required to use this Kaggle dataset: https://www.kaggle.com/datasets/jangedoo/utkface-new
This dataset presents gender as a binary. However, it's crucial to understand that this binary representation is a simplification and doesn't reflect the complexity of gender.
Gender is not binary in any aspect - be it physical, physiological, hormonal, in terms of gender expression, or gender identity.
This binary labeling was used due to the constraints of the project, not as an endorsement of a binary view of gender.
From the perspective of queer theory and LGBT rights, this binary representation can be problematic. It overlooks the experiences and identities of those who don't fit within this binary, including transgender, non-binary, and genderqueer individuals. This can lead to erasure and marginalization, reinforcing harmful stereotypes and discrimination.
Therefore, while our project used a binary gender label, we acknowledge its limitations and advocate for more inclusive and nuanced representations of gender in data. We believe in the importance of recognizing and respecting all gender identities and expressions, as a fundamental aspect of human rights and dignity.
This dataset, with its binary representation of gender, is far from ideal. It inadvertently perpetuates a narrow and oversimplified perception of gender among students. By presenting gender as a binary, it may lead students to internalize this limited understanding, thereby reinforcing the societal norms that queer theory and LGBTQIA+ rights movements challenge. This can hinder the development of a more inclusive, diverse, and accurate understanding of gender. It’s essential to critically engage with such datasets and question the assumptions they make, to foster a more comprehensive and respectful understanding of gender diversity. We must strive for datasets that reflect the reality of human experience, in all its richness and diversity.
The model seemed somewhat prone to overfitting, showing strange behaviour when trained for more than 30 epochs:
This appeared as wildly incorrect predictions (even on the training dataset):

These also appeared in the validation dataset.
Interestingly, this bias affected only pictures of males.

This was fixed by reducing the number of training epochs.